# Preparation

In [1]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from IPython.display import clear_output

In [2]:
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print(device)

cuda:0


In [3]:
import wandb
wandb.login()

[34m[1mwandb[0m: Currently logged in as: [33mchangli_824[0m. Use [1m`wandb login --relogin`[0m to force relogin


True

# Define Network

In [4]:
class ConvNet(nn.Module):
    def __init__(self, kernel_size = 3, activation_fn = nn.ReLU()):
        super().__init__()

        self.max_pooling_2 = nn.MaxPool3d(kernel_size = 2)

        self.up_sampling_2 = nn.Upsample(scale_factor = 2)

        self.conv64_1_8 = nn.Sequential(
            nn.Conv3d(in_channels = 1, out_channels = 8, kernel_size = kernel_size, padding = 'same'),
            nn.BatchNorm3d(num_features = 8),
            activation_fn
        )

        self.conv64_8_8 = nn.Sequential(
            nn.Conv3d(in_channels = 8, out_channels = 8, kernel_size = kernel_size, padding = 'same'),
            nn.BatchNorm3d(num_features = 8),
            activation_fn
        )

        self.conv32_8_32 = nn.Sequential(
            nn.Conv3d(in_channels = 8, out_channels = 32, kernel_size = kernel_size, padding = 'same'),
            nn.BatchNorm3d(num_features = 32),
            activation_fn
        )

        self.conv32_32_32 = nn.Sequential(
            nn.Conv3d(in_channels = 32, out_channels = 32, kernel_size = kernel_size, padding = 'same'),
            nn.BatchNorm3d(num_features = 32),
            activation_fn
        )

        self.conv16_32_128 = nn.Sequential(
            nn.Conv3d(in_channels = 32, out_channels = 128, kernel_size = kernel_size, padding = 'same'),
            nn.BatchNorm3d(num_features = 128),
            activation_fn
        )

        self.conv16_128_128 = nn.Sequential(
            nn.Conv3d(in_channels = 128, out_channels = 128, kernel_size = kernel_size, padding = 'same'),
            nn.BatchNorm3d(num_features = 128),
            activation_fn
        )

        self.conv8_128_256 = nn.Sequential(
            nn.Conv3d(in_channels = 128, out_channels = 256, kernel_size = kernel_size, padding = 'same'),
            nn.BatchNorm3d(num_features = 256),
            activation_fn
        )

        self.conv8_256_256 = nn.Sequential(
            nn.Conv3d(in_channels = 256, out_channels = 256, kernel_size = kernel_size, padding = 'same'),
            nn.BatchNorm3d(num_features = 256),
            activation_fn
        )

        self.conv16_384_128 = nn.Sequential(
            nn.Conv3d(in_channels = 384, out_channels = 128, kernel_size = kernel_size, padding = 'same'),
            nn.BatchNorm3d(num_features = 128),
            activation_fn
        )

        self.conv32_160_32 = nn.Sequential(
            nn.Conv3d(in_channels = 160, out_channels = 32, kernel_size = kernel_size, padding = 'same'),
            nn.BatchNorm3d(num_features = 32),
            activation_fn
        )

        self.conv64_40_8 = nn.Sequential(
            nn.Conv3d(in_channels = 40, out_channels = 8, kernel_size = kernel_size, padding = 'same'),
            nn.BatchNorm3d(num_features = 8),
            activation_fn
        )

        self.conv64_8_1 = nn.Sequential(
            nn.Conv3d(in_channels = 8, out_channels = 1, kernel_size = kernel_size, padding = 'same'),
            activation_fn
        )

    def forward(self, x):
        x = self.conv64_1_8(x)
        x = self.conv64_8_8(x)
        feature_map_64 = x.detach()
        x = self.max_pooling_2(x)
        x = self.conv32_8_32(x)
        x = self.conv32_32_32(x)
        feature_map_32 = x.detach()
        x = self.max_pooling_2(x)
        x = self.conv16_32_128(x)
        x = self.conv16_128_128(x)
        feature_map_16 = x.detach()
        x = self.max_pooling_2(x)
        x = self.conv8_128_256(x)
        x = self.conv8_256_256(x)
        x = self.up_sampling_2(x)
        x = torch.cat((feature_map_16, x), dim = 1)
        x = self.conv16_384_128(x)
        x = self.conv16_128_128(x)
        x = self.up_sampling_2(x)
        x = torch.cat((feature_map_32, x), dim = 1)
        x = self.conv32_160_32(x)
        x = self.conv32_32_32(x)
        x = self.up_sampling_2(x)
        x = torch.cat((feature_map_64, x), dim = 1)
        x = self.conv64_40_8(x)
        x = self.conv64_8_1(x)
        return x

In [5]:
def train_epoch(model, training_loader, optimizer, loss_fn):
    cumulative_loss = 0.0
    for i, data in enumerate(training_loader):
        inputs, labels = data

        # Zero the gradients
        optimizer.zero_grad()

        # Make predictions
        outputs = model(inputs)

        # Compute loss and its gradients
        loss = loss_fn(outputs, labels)
        loss.backward()

        # Adjust learning weights
        optimizer.step()

        cumulative_loss += loss.item()
        
        wandb.log({'batch loss': loss.item()})
    return cumulative_loss / len(training_loader), cumulative_loss

In [6]:
def train(config, loss_fn):
    clear_output(wait = True)
    
    # initialize a wandb run
    wandb.init(config = config)

    # copy the config
    config = wandb.config
    
    print('config:', config)

    # get training loader
    training_loader = DataLoader(list(zip(X_train, y_train)), batch_size = config.batch_size, shuffle = True)

    # initialize model
    if config.activation_fn == 'ReLU':
        activation_fn = nn.ReLU()
    
    if config.activation_fn == 'Sigmoid':
        activation_fn = nn.Sigmoid()
    
    model = ConvNet(kernel_size = config.kernel_size, activation_fn = activation_fn).to(device)
    
    optimizer = torch.optim.SGD(model.parameters(), lr = config.learning_rate, momentum = 0.9)

    for epoch in range(config.epochs_choice):
        avg_loss_per_batch, cumulative_loss = train_epoch(model, training_loader, optimizer, loss_fn)
        wandb.log({'avg_loss_per_batch': avg_loss_per_batch, 'cumulative_loss': cumulative_loss})
        print(f'Loss for epoch {epoch}: {cumulative_loss}')
    
    return model

In [7]:
def test(config, model, loss_fn):
    # copy the config
    config = wandb.config
    
    # get testing loader
    testing_loader = DataLoader(list(zip(X_test, y_test)), batch_size = config.batch_size, shuffle = True)
    
    testing_loss = 0.0
    for i, data in enumerate(testing_loader):
        inputs, labels = data
        outputs = model(inputs)
        loss = loss_fn(outputs, labels)
        testing_loss += loss.item()
    return testing_loss / len(testing_loader), testing_loss

In [8]:
def evaluate(config = None):
    loss_fn = nn.MSELoss()
    model = train(config, loss_fn)
    avg_loss_per_batch_test, testing_loss = test(config, model, loss_fn)
    wandb.log({'avg_loss_per_batch_test': avg_loss_per_batch_test, 'testing_loss': testing_loss})

# Read in data

In [9]:
X = torch.rand((1280, 1, 64, 64, 64), device = device)
y = torch.rand((1280, 1, 64, 64, 64), device = device)
X_train, X_test = torch.utils.data.random_split(X, [0.8, 0.2])
y_train, y_test = torch.utils.data.random_split(y, [0.8, 0.2])

# Training settings

In [10]:
sweep_config = {
    'method': 'grid'
    }
metric = {
    'name': 'testing_loss',
    'goal': 'minimize'
    }
sweep_config['metric'] = metric
parameters_dict = {
    'kernel_size': {
        'values': [3, 4, 5]
    },
    'activation_fn': {
        'values': ['ReLU', 'Sigmoid']
    },
    'epochs_choice': {
          'values': [5, 10, 20]
    },
    'learning_rate': {
        'values': [1e-4, 1e-3, 1e-2]
    },
    'batch_size': {
        'values': [8, 4]
    },
}

sweep_config['parameters'] = parameters_dict

# Start

In [11]:
sweep_id = wandb.sweep(sweep_config, project = 'CNN_sweep')

Create sweep with ID: g2wh6a8o
Sweep URL: https://wandb.ai/changli_824/CNN_sweep/sweeps/g2wh6a8o


In [None]:
wandb.agent(sweep_id = sweep_id, function = evaluate)

config: {'activation_fn': 'ReLU', 'batch_size': 8, 'epochs_choice': 5, 'kernel_size': 3, 'learning_rate': 0.0001}
Loss for epoch 0: 22.592270016670227


Next Steps:<br>
Automate hyperparameter tuning, potentially with weights and biases<br>
Understand and figure out ways to limit GPU memory usage<br>
Try binary/short inputs<br>
Replace dummy data with real data<br>
Run with google drive data, get github data set up