<a href="https://colab.research.google.com/github/nghess/607-sensory-coding/blob/main/3dCNN_v1_grid_colab.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [65]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import transforms
import matplotlib.pyplot as plt
import tifffile as tiff
import pandas as pd
import numpy as np
import os

In [66]:
from google.colab import drive

# Mount Google Drive so the project folder under "My Drive" becomes reachable,
# then remember that folder as the notebook's base path.
drive.mount('/content/drive')
path = os.path.join('/content/drive', 'My Drive', '607_sensory_coding')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [67]:
# Work from the repository root so the relative "dataset/..." paths used below resolve.
repo_path = os.path.join(path, "607-sensory-coding")
os.chdir(repo_path)

In [68]:
# Define dataset: Tiff stack with two labels
class TiffDataset(Dataset):
    """Dataset of TIFF stacks, each paired with a (rotation, angle) label.

    The annotations CSV is read positionally:
      * column 0 -- image path, relative to ``root_dir``
      * column 1 -- rotation direction; the string ``'clockwise'`` maps to
        class 1, any other value to class 0
      * column 2 -- integer angle class

    Args:
        csv_file: Path to the annotations CSV.
        root_dir: Directory that the column-0 paths are relative to.
        transform: Optional callable applied to the normalized
            ``(1, *stack_shape)`` float tensor. Previously this argument was
            accepted but silently ignored.
    """

    def __init__(self, csv_file, root_dir, transform=None):
        self.annotations = pd.read_csv(csv_file)
        self.root_dir = root_dir
        self.transform = transform

    def __len__(self):
        return len(self.annotations)

    def __getitem__(self, index):
        """Load stack ``index`` and return ``(image, label)``.

        Returns:
            image: float32 tensor of shape ``(1, *stack_shape)`` (a leading
                channel dim is added), scaled by 1/255.
            label: int64 tensor ``[rotation_class, angle_class]``.
        """
        img_name = os.path.join(self.root_dir, self.annotations.iloc[index, 0])
        stack = tiff.imread(img_name)  # Load stack

        # Scale to [0, 1]. NOTE(review): dividing by 255 assumes 8-bit stacks -- confirm dtype.
        image = torch.from_numpy(stack).float() / 255.0

        # Add channel dim so Conv3d sees (C=1, D, H, W)
        image = image.unsqueeze(0)

        if self.transform is not None:
            image = self.transform(image)

        # A transform may return a numpy array or another dtype; the model expects float32.
        image = torch.as_tensor(image, dtype=torch.float32)

        # Labels are plain class indices (as CrossEntropyLoss expects), not one-hot.
        rotation_class = 1 if self.annotations.iloc[index, 1] == 'clockwise' else 0
        angle_class = int(self.annotations.iloc[index, 2])
        label = torch.tensor([rotation_class, angle_class], dtype=torch.long)

        return image, label

# Build the dataset and split it into train/test loaders.
SPLIT_SEED = 42        # fixes the train/test split across Restart & Run All
TRAIN_FRACTION = 0.6   # share of samples used for training
BATCH_SIZE = 360

dataset = TiffDataset(csv_file='dataset/slice/labels_slice.csv', root_dir='dataset/')

# Determine the lengths for train and test sets
train_size = int(TRAIN_FRACTION * len(dataset))
test_size = len(dataset) - train_size

# A seeded generator makes random_split deterministic; without it every re-run
# reshuffles which samples land in the test set, so results are not comparable.
split_generator = torch.Generator().manual_seed(SPLIT_SEED)
train_dataset, test_dataset = random_split(dataset, [train_size, test_size],
                                           generator=split_generator)

# Create data loaders
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

In [69]:
# Sanity-check: print the tensor shape of the first five stacks.
for sample_idx in range(5):
    sample_image, sample_label = dataset[sample_idx]
    print(f"Image shape for sample {sample_idx}:", sample_image.shape)

Image shape for sample 0: torch.Size([1, 5, 128, 128])
Image shape for sample 1: torch.Size([1, 5, 128, 128])
Image shape for sample 2: torch.Size([1, 5, 128, 128])
Image shape for sample 3: torch.Size([1, 5, 128, 128])
Image shape for sample 4: torch.Size([1, 5, 128, 128])


## Define Model

In [70]:
class Simple3DCNN(nn.Module):
    """Two conv/pool stages followed by a shared FC layer and two classification heads.

    Args:
        input_shape: ``(depth, height, width)`` of one stack, channel dim excluded.
            The default ``(5, 128, 128)`` matches the sample shapes printed above.
            Every dimension must be >= 4 to survive the two 2x poolings.
        n_rotation_classes: Output size of the rotation head.
        n_orientation_classes: Output size of the orientation head.

    ``forward`` returns ``(rotation_logits, orientation_logits)``.
    """

    def __init__(self, input_shape=(5, 128, 128), n_rotation_classes=2, n_orientation_classes=36):
        super().__init__()
        depth, height, width = input_shape
        if min(input_shape) < 4:
            raise ValueError(f"every input dimension must be >= 4, got {tuple(input_shape)}")

        self.conv1 = nn.Conv3d(in_channels=1, out_channels=16, kernel_size=3, stride=1, padding=1)
        self.pool = nn.MaxPool3d(kernel_size=2, stride=2, padding=0)
        self.conv2 = nn.Conv3d(16, 16, 3, padding=1)

        # Padded 3x3x3 convs keep the spatial size; each 2x pool floors it by half,
        # so every dimension ends at dim // 4 (default: 5x128x128 -> 1x32x32).
        self.flat_features = 16 * (depth // 4) * (height // 4) * (width // 4)

        self.fc1 = nn.Linear(self.flat_features, 512)
        self.fc_rotation = nn.Linear(512, n_rotation_classes)
        self.fc_orientation = nn.Linear(512, n_orientation_classes)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))  # Apply 3D convolution and pooling
        x = self.pool(F.relu(self.conv2(x)))  # Apply 3D convolution and pooling
        # Flatten per sample. Unlike view(-1, N), this cannot silently fold a
        # wrongly-sized input into a different batch size; fc1 raises instead.
        x = torch.flatten(x, start_dim=1)
        x = F.relu(self.fc1(x))  # Fully connected layer
        rotation_output = self.fc_rotation(x)
        orientation_output = self.fc_orientation(x)
        return rotation_output, orientation_output

model = Simple3DCNN()

## Train Model

In [None]:
epochs = 100

# Check if CUDA is available
if torch.cuda.is_available():
    print("CUDA is active.")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Move the model before building the optimizer, as the torch.optim docs recommend,
# so the optimizer is constructed over the parameters that will be updated.
model = model.to(device)

# Define loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

# Per-epoch mean training loss (averaged per sample)
loss_values = []

# NOTE(review): the logged loss (~4.277) matches ln(2) + ln(36), i.e. chance level
# for both heads, so the model may not be learning -- check lr / label ranges.
for epoch in range(epochs):
    # test_model()/plot_feature_maps() switch to eval mode; switch back so that
    # re-running this cell trains in training mode.
    model.train()
    running_loss = 0.0
    n_seen = 0
    for inputs, labels in train_loader:
        inputs, labels = inputs.to(device), labels.to(device)

        # Column 0: rotation class, column 1: angle/orientation class
        rotation_labels = labels[:, 0]
        input_labels = labels[:, 1]

        optimizer.zero_grad()
        rotation_output, orientation_output = model(inputs)

        # Sum the two heads' cross-entropy losses with equal weight
        loss_rotation = criterion(rotation_output, rotation_labels)
        loss_input = criterion(orientation_output, input_labels)
        loss = loss_rotation + loss_input

        loss.backward()
        optimizer.step()

        # Weight by batch size so a short final batch doesn't skew the epoch mean
        batch_size = inputs.size(0)
        running_loss += loss.item() * batch_size
        n_seen += batch_size

    epoch_loss = running_loss / n_seen
    loss_values.append(epoch_loss)
    print(f'Epoch [{epoch + 1}/{epochs}], Loss: {epoch_loss:.4f}')

CUDA is active.
Epoch [1/100], Loss: 4.2765
Epoch [2/100], Loss: 4.2766


### Save Model to File

In [None]:
# Show the layer-by-layer architecture summary
print(str(model))

In [None]:
# Uncomment to persist the trained model to Drive. This pickles the whole module;
# saving model.state_dict() instead is the form the PyTorch serialization docs
# recommend, and it survives changes to the class definition.
#torch.save(model, '/content/drive/My Drive/607_sensory_coding/test_model_2.pt')

## Test Model

In [None]:
# Test label prediction performance
def test_model(model, test_loader, device):
    model.eval()  # Set the model to evaluation mode
    correct_rotation, correct_input, total = 0, 0, 0

    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            rotation_labels, input_labels = labels[:, 0], labels[:, 1]

            rotation_output, input_output = model(inputs)

            _, predicted_rotation = torch.max(rotation_output.data, 1)
            _, predicted_input = torch.max(input_output.data, 1)

            total += labels.size(0)
            correct_rotation += (predicted_rotation == rotation_labels).sum().item()
            correct_input += (predicted_input == input_labels).sum().item()

    print(f'Accuracy of the network on rotation prediction: {100 * correct_rotation / total}%')
    print(f'Accuracy of the network on input type prediction: {100 * correct_input / total}%')

# Evaluate the trained network on the held-out split
test_model(model=model, test_loader=test_loader, device=device)

## Visualize Trained Layer Activations (Feature Maps)

In [None]:
# Function to pull a single sample input out for layer visualization
def get_sample_input(data_loader, index=0):
    """Return one input from the first batch of ``data_loader``.

    Args:
        data_loader: Any iterable yielding ``(inputs, labels)`` batches.
        index: Position of the sample within the first batch (default 0).

    Returns:
        ``inputs[index]`` from the first batch, i.e. without the batch dim.

    Raises:
        ValueError: If ``data_loader`` yields no batches. Previously this
            case silently returned ``None`` and failed later.
    """
    try:
        inputs, _ = next(iter(data_loader))
    except StopIteration:
        raise ValueError("data_loader is empty; cannot draw a sample input") from None
    return inputs[index]

# First test-set stack, used below for feature-map visualization
sample_input = get_sample_input(data_loader=test_loader)

In [None]:
def plot_feature_maps(model, input_tensor, selected_layers, ncols=4, device=None, depth_index=0):
    """Plot each channel of the activations produced by the selected layers.

    A forward hook captures the output of every module whose name (from
    ``model.named_modules()``) is in ``selected_layers`` while one input is
    run through the model. One figure is then drawn per captured layer,
    showing depth slice ``depth_index`` of every channel in a grid.

    Args:
        model: The network to probe. It is put into eval mode.
        input_tensor: One sample without the batch dim; a batch dim is added here.
        selected_layers: Module names to capture, e.g. ``['conv2']``.
        ncols: Number of subplot columns per figure.
        device: Device to run on. Defaults to the device of the model's first
            parameter, or CPU for a parameterless model. The original used
            the notebook-global ``device``.
        depth_index: Depth slice of each channel to display (default 0).
    """
    model.eval()

    feature_maps = []

    def capture_output(module, inputs, output):
        # detach + cpu so the stored maps hold no autograd graph or GPU memory
        feature_maps.append(output.detach().cpu().numpy())

    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cpu')

    hooks = [
        module.register_forward_hook(capture_output)
        for name, module in model.named_modules()
        if name in selected_layers
    ]
    try:
        with torch.no_grad():
            model(input_tensor.unsqueeze(0).to(device))
    finally:
        # Always remove the hooks, even if the forward pass fails, so they
        # don't keep firing (and accumulating maps) on later forward calls.
        for hook in hooks:
            hook.remove()

    for layer_maps in feature_maps:
        # Indexed as (batch, channel, depth, ...); assumes a Conv3d-style 5-D output.
        n_features = layer_maps.shape[1]
        nrows = -(-n_features // ncols)  # ceiling division
        # squeeze=False keeps axes 2-D, so a single row or column indexes uniformly
        fig, axes = plt.subplots(nrows, ncols, figsize=(20, 2 * nrows), squeeze=False)
        for channel, ax in enumerate(axes.flat):
            if channel < n_features:
                ax.imshow(layer_maps[0, channel, depth_index, :], cmap='gray')
            ax.axis('off')  # also hides the unused trailing panels
        plt.show()

# Visualize the second conv layer's activations for one test stack, six panels per row
selected_layers = ['conv2']
plot_feature_maps(model=model, input_tensor=sample_input, selected_layers=selected_layers, ncols=6)


In [None]:
# Release every open matplotlib figure
plt.close(fig='all')

In [None]:
# Preview the first depth slice of the first few stacks, each titled with its own label.
ncols = 5
nrows = 7
n_show = min(nrows * ncols, len(dataset))
fig, axes = plt.subplots(nrows, ncols, figsize=(20, 2 * nrows), squeeze=False)

for i, ax in enumerate(axes.flat):
    ax.axis('off')
    if i >= n_show:
        continue  # fewer samples than panels: leave the rest blank
    # Load each sample once; the label is [rotation_class, angle_class].
    sample_image, sample_label = dataset[i]
    # Title from this sample's own label. Previously every panel showed the first
    # label of the last *training batch* (a leftover global from the training cell).
    ax.set_title(f"{sample_label[0].item()}, {sample_label[1].item()}")
    # The image is (1, depth, H, W); take channel 0, depth slice 0 to get a 2-D array.
    # The previous img[:, 0, :] kept a leading dim of 1, which imshow rejects.
    ax.imshow(sample_image[0, 0], cmap='gray')

plt.show()