In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class SimpleCNN(nn.Module):
    def __init__(self):
        super(SimpleCNN, self).__init__()
        
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=32, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=1, padding=1)
        self.conv3 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=1, padding=1)
        
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
        
        self.fc1 = None
        self.fc2 = nn.Linear(256, 128)
        self.fc3 = nn.Linear(128, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = self.pool(F.relu(self.conv3(x)))
        
        if self.fc1 is None:
            flattened_size = x.view(x.size(0), -1).size(1)
            self.fc1 = nn.Linear(flattened_size, 256).to(x.device)
        
        x = x.view(x.size(0), -1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)

        return x


In [3]:
import torch.optim as optim
from torch.utils.data import DataLoader


def train_network(network, dataloader, device, num_epochs=10, learning_rate=0.001, use_metric_learning=False):
    # Criterion for cross-entropy loss
    criterion = nn.CrossEntropyLoss()
    
    # Criterion for metric learning (e.g., TripletMarginLoss)
    metric_criterion = nn.TripletMarginLoss(margin=1.0, p=2)
    
    optimizer = optim.Adam(network.parameters(), lr=learning_rate)
    
    for epoch in range(num_epochs):
        running_loss = 0.0
        for data in dataloader:
            if use_metric_learning:
                
                anchor, positive, negative = data
                anchor, positive, negative = anchor.to(device), positive.to(device), negative.to(device)
                
                optimizer.zero_grad()
                
                anchor_out = network(anchor)
                positive_out = network(positive)
                negative_out = network(negative)
                
                loss = metric_criterion(anchor_out, positive_out, negative_out)
            else:
                
                images, labels = data
                images, labels = images.to(device), labels.to(device)
                
                optimizer.zero_grad()
                
                outputs = network(images)
                
                loss = criterion(outputs, labels)
            
            
            loss.backward()
            optimizer.step()
            
            running_loss += loss.item()
        
        print(f'Epoch [{epoch + 1}/{num_epochs}], Loss: {running_loss / len(dataloader)}')

    print('Training completed.')




In [4]:
import Utils.lung_cancer_data 
import numpy as np
from Utils.lung_cancer_data import get_dataloader

dataloader_metric = get_dataloader("Data/Image/IQ-OTHNCCD", 32, True, metric_learning=True)
dataloader = get_dataloader("Data/Image/IQ-OTHNCCD", 32, True, metric_learning=False)


if torch.cuda.is_available():
    device = torch.device('cuda')
elif torch.backends.mps.is_available():
    device = torch.device('mps')  # For M1 Macs 
else:
    device = torch.device('cpu')

print(f"Using device: {device}")

SimpleCNN = SimpleCNN().to(device)
train_network(SimpleCNN, dataloader_metric, device, use_metric_learning=True)
train_network(SimpleCNN, dataloader, device, use_metric_learning=False)

 

Using device: mps
Epoch [1/10], Loss: 0.7814665990216392
Epoch [2/10], Loss: 0.359001539008958
Epoch [3/10], Loss: 0.1984188453427383
Epoch [4/10], Loss: 0.12451365324003356
Epoch [5/10], Loss: 0.12536260166338511
Epoch [6/10], Loss: 0.088811751135758
Epoch [7/10], Loss: 0.02833239617092269
Epoch [8/10], Loss: 0.04446779383080346
Epoch [9/10], Loss: 0.07832080615418298
Epoch [10/10], Loss: 0.05452170755181994
Training completed.
Epoch [1/10], Loss: 15.423603243487221
Epoch [2/10], Loss: 0.37230201278414043
Epoch [3/10], Loss: 0.0621061984449625
Epoch [4/10], Loss: 0.028501795155794492
Epoch [5/10], Loss: 0.012664317612403206
Epoch [6/10], Loss: 0.014058543613646179
Epoch [7/10], Loss: 0.006381686475859689
Epoch [8/10], Loss: 0.011747867510920124
Epoch [9/10], Loss: 0.009506057124651437
Epoch [10/10], Loss: 0.016538717606038388
Training completed.
