In [1]:
import copy
import itertools

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from IPython.display import clear_output

In [2]:
# Prefer the first CUDA GPU; fall back to the CPU when none is available.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')
print(device)

cuda:0


In [3]:
# Toggle Weights & Biases experiment tracking for the sweep below.
# NOTE: this name shadows the stdlib `logging` module within this notebook.
logging: bool = True

In [4]:
# W&B is only needed when experiment tracking is switched on, so both the
# import and the authentication step are skipped otherwise.
if logging:
    import wandb
    wandb.login()  # authenticate this machine with Weights & Biases

[34m[1mwandb[0m: Currently logged in as: [33mchangli_824[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [5]:
class ConvNet(nn.Module):
    """3D U-Net-style encoder/decoder mapping ``(N, 1, D, H, W)`` volumes to the same shape.

    Encoder: three stages of two ``conv -> BatchNorm -> activation`` blocks, each
    stage followed by 2x max pooling (channels 1 -> 8 -> 32 -> 128), then a
    256-channel bottleneck. Decoder: three 2x nearest-neighbour upsamplings, each
    followed by channel-wise concatenation with the encoder feature map of the
    same resolution (skip connection) and two conv blocks. Layer names read
    ``conv<spatial size for a 64^3 input>_<in channels>_<out channels>``.

    Args:
        kernel_size: Kernel size of every ``Conv3d`` (``padding='same'`` keeps the
            spatial size unchanged).
        activation_fn: Activation applied after every convolution, including the
            last one. Each layer gets its own deep copy, so activations with
            learnable parameters (e.g. ``nn.PReLU``) are not tied across layers.

    ``D``, ``H`` and ``W`` must be divisible by 8 so the three upsamplings restore
    exactly the sizes removed by the three poolings.
    """

    def __init__(self, kernel_size = 3, activation_fn = nn.ReLU()):
        super().__init__()

        def conv_block(in_channels, out_channels, batch_norm = True):
            # Conv3d('same') -> [BatchNorm3d] -> activation. The layout
            # [conv, (batch norm,) activation] keeps parameter names as
            # `<layer>.0.*` (conv) and `<layer>.1.*` (batch norm).
            layers = [nn.Conv3d(in_channels = in_channels, out_channels = out_channels,
                                kernel_size = kernel_size, padding = 'same')]
            if batch_norm:
                layers.append(nn.BatchNorm3d(num_features = out_channels))
            layers.append(copy.deepcopy(activation_fn))
            return nn.Sequential(*layers)

        self.max_pooling_2 = nn.MaxPool3d(kernel_size = 2)
        self.up_sampling_2 = nn.Upsample(scale_factor = 2)

        # Encoder
        self.conv64_1_8 = conv_block(1, 8)
        self.conv64_8_8 = conv_block(8, 8)
        self.conv32_8_32 = conv_block(8, 32)
        self.conv32_32_32 = conv_block(32, 32)
        self.conv16_32_128 = conv_block(32, 128)
        self.conv16_128_128 = conv_block(128, 128)
        # Bottleneck
        self.conv8_128_256 = conv_block(128, 256)
        self.conv8_256_256 = conv_block(256, 256)
        # Decoder entry convs: input channels = upsampled channels + skip channels
        self.conv16_384_128 = conv_block(384, 128)  # 256 + 128
        self.conv32_160_32 = conv_block(160, 32)    # 128 + 32
        self.conv64_40_8 = conv_block(40, 8)        # 32 + 8
        # Output layer has no BatchNorm.
        self.conv64_8_1 = conv_block(8, 1, batch_norm = False)
        # Decoder counterparts of conv16_128_128 / conv32_32_32. They are separate
        # modules so the decoder shares neither weights nor BatchNorm running
        # statistics with the encoder. Created last so that, under a fixed seed,
        # adding them does not change the initialisation of the layers above.
        self.dec16_128_128 = conv_block(128, 128)
        self.dec32_32_32 = conv_block(32, 32)

    def forward(self, x):
        """Map ``x`` of shape ``(N, 1, D, H, W)`` to an output of the same shape.

        Raises:
            ValueError: if ``x`` is not 5D or a spatial size is not divisible by 8.
        """
        if x.dim() != 5 or any(size % 8 for size in x.shape[2:]):
            raise ValueError(
                'ConvNet expects input of shape (N, C, D, H, W) with D, H and W '
                f'divisible by 8, got {tuple(x.shape)}'
            )

        # Encoder. The skip tensors stay attached to the autograd graph (no
        # .detach()) so the encoder also learns through the skip connections.
        feature_map_64 = self.conv64_8_8(self.conv64_1_8(x))
        x = self.max_pooling_2(feature_map_64)
        feature_map_32 = self.conv32_32_32(self.conv32_8_32(x))
        x = self.max_pooling_2(feature_map_32)
        feature_map_16 = self.conv16_128_128(self.conv16_32_128(x))
        x = self.max_pooling_2(feature_map_16)

        # Bottleneck
        x = self.conv8_256_256(self.conv8_128_256(x))

        # Decoder: upsample, concatenate the matching skip along channels, convolve.
        x = torch.cat((feature_map_16, self.up_sampling_2(x)), dim = 1)
        x = self.dec16_128_128(self.conv16_384_128(x))
        x = torch.cat((feature_map_32, self.up_sampling_2(x)), dim = 1)
        x = self.dec32_32_32(self.conv32_160_32(x))
        x = torch.cat((feature_map_64, self.up_sampling_2(x)), dim = 1)
        return self.conv64_8_1(self.conv64_40_8(x))

In [6]:
def train(epochs, model, training_loader, lr):
    """Train ``model`` in place with SGD (momentum 0.9) on an MSE objective.

    Args:
        epochs: Number of full passes over ``training_loader``.
        model: The ``nn.Module`` to optimize; it is switched to training mode first.
        training_loader: Iterable yielding ``(inputs, labels)`` batches.
        lr: SGD learning rate.

    Returns:
        A list with one entry per epoch: the sum of that epoch's per-batch mean
        losses (the same number that is printed).
    """
    loss_fn = nn.MSELoss()
    optimizer = torch.optim.SGD(model.parameters(), lr = lr, momentum = 0.9)
    # Make sure BatchNorm uses batch statistics and updates its running stats,
    # even if the model was left in eval mode by an earlier evaluation.
    model.train()

    epoch_losses = []
    for epoch in range(epochs):
        running_loss = 0.0
        for inputs, labels in training_loader:
            optimizer.zero_grad()

            # Forward pass, then loss and its gradients
            loss = loss_fn(model(inputs), labels)
            loss.backward()

            # Adjust learning weights
            optimizer.step()

            running_loss += loss.item()
        print(f'Loss for epoch {epoch}: {running_loss}')
        epoch_losses.append(running_loss)
    return epoch_losses

In [7]:
def evaluate_loss(model, data_loader):
    """Return the summed per-batch MSE of ``model`` over ``data_loader`` without changing it.

    The model is put in eval mode, so BatchNorm uses its running statistics and does
    not update them. Autograd is disabled, so no activation graphs are kept alive in
    GPU memory. The model's previous train/eval mode is restored afterwards, even if
    an exception is raised.

    Args:
        model: The ``nn.Module`` to evaluate.
        data_loader: Iterable yielding ``(inputs, labels)`` batches.

    Returns:
        Sum over batches of the batch-mean MSE.
    """
    loss_fn = nn.MSELoss()
    was_training = model.training
    model.eval()
    running_loss = 0.0
    try:
        with torch.no_grad():
            for inputs, labels in data_loader:
                running_loss += loss_fn(model(inputs), labels).item()
    finally:
        model.train(was_training)
    return running_loss

In [8]:
# Dummy data: inputs and targets are drawn independently from U[0, 1), so the best
# any model can do under MSE is predict the target mean (0.5) everywhere.
# Seeding makes the data and the subsequent model initialisations reproducible.
torch.manual_seed(0)

N_SAMPLES = 1280
VOLUME_SIZE = 64  # must be divisible by 8 for ConvNet's three pool/upsample stages

# Each float32 tensor is N_SAMPLES * VOLUME_SIZE**3 * 4 bytes (~1.25 GiB here).
X = torch.rand((N_SAMPLES, 1, VOLUME_SIZE, VOLUME_SIZE, VOLUME_SIZE), device = device)
y = torch.rand((N_SAMPLES, 1, VOLUME_SIZE, VOLUME_SIZE, VOLUME_SIZE), device = device)

In [9]:
# Search space for the grid sweep below: 3 * 2 * 3 * 3 = 54 training runs.
kernel_sizes = list(range(3, 6))            # cubic kernels of side 3, 4 and 5
activation_fns = [nn.ReLU(), nn.Sigmoid()]
epochs_choices = [5, 10, 20]
learning_rates = [0.0001, 0.001, 0.01]

In [10]:
# Grid search: train one fresh model per hyperparameter combination and report the
# final training loss. The dataset never changes between runs, so the loader is
# built once here rather than rebuilding `list(zip(X, y))` for every run.
training_loader = DataLoader(list(zip(X, y)), batch_size = 8, shuffle = True)

for kernel_size, activation_fn, epochs, learning_rate in itertools.product(
        kernel_sizes, activation_fns, epochs_choices, learning_rates):
    clear_output(wait=True)

    config = {
        'kernel_size': kernel_size,
        'activation_fn': activation_fn,
        'epochs': epochs,
        'learning_rate': learning_rate
    }

    if logging:
        # initialize a wandb run
        wandb.init(
            project = 'CNN_first_test',
            config = config,
            name = str(config)
        )

        # copy the config
        config = wandb.config

    try:
        # Drop the previous run's model before allocating the next one so it can
        # be freed first, instead of two models sharing GPU memory.
        model = None
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        model = ConvNet(kernel_size = kernel_size, activation_fn = activation_fn).to(device)

        train(epochs, model, training_loader, learning_rate)

        loss = evaluate_loss(model, training_loader)

        print(f'{config}: {loss}')

        if logging:
            wandb.log({'training_loss': loss})
    finally:
        # Close the W&B run even if training is interrupted or raises.
        if logging:
            wandb.finish()

VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011111111111111112, max=1.0…

Loss for epoch 0: 13.362125054001808
Loss for epoch 1: 13.340454988181591
Loss for epoch 2: 13.338420696556568
Loss for epoch 3: 13.33737763017416
Loss for epoch 4: 13.336708277463913
Loss for epoch 5: 13.336253270506859
Loss for epoch 6: 13.335901737213135
Loss for epoch 7: 13.335637055337429
Loss for epoch 8: 13.335419490933418
Loss for epoch 9: 13.33523490279913
Loss for epoch 10: 13.335081852972507
Loss for epoch 11: 13.334949217736721
Loss for epoch 12: 13.334832720458508
Loss for epoch 13: 13.334734372794628
Loss for epoch 14: 13.334646098315716
Loss for epoch 15: 13.334568418562412
Loss for epoch 16: 13.33449475467205
Loss for epoch 17: 13.334428615868092
Loss for epoch 18: 13.334374837577343
Loss for epoch 19: 13.334315426647663
{'kernel_size': 5, 'activation_fn': 'Sigmoid()', 'epochs': 20, 'learning_rate': 0.01}: 13.334305845201015


VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
training_loss,▁

0,1
training_loss,13.33431


**Next steps**

- Automate hyperparameter tuning, potentially with Weights & Biases
- Understand and find ways to limit GPU memory usage
- Try binary/short inputs
- Replace the dummy data with real data
- Run with the Google Drive data and set up the GitHub dataset