<h3>Imports</h3>

In [47]:
from PIL import Image
from collections import Counter
import datetime, os
import matplotlib.pyplot as plt
import numpy as np
import os
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
from torchvision import datasets, transforms
from torch.utils.data import Dataset
from torch.utils.data import Dataset, DataLoader, random_split
from torch.utils.tensorboard import SummaryWriter
import torchvision.models as models
import torchvision.transforms as transforms

DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
DataLoader = torch.utils.data.DataLoader

SEED = 1
torch.manual_seed(SEED)

<torch._C.Generator at 0x107721d50>

<h3>Set-up classes and mappings</h3>

In [48]:
classes = ["ACK", "BCC", "MEL", "NEV", "SCC", "SEK"]
class_to_idx = {cls_name: idx for idx, cls_name in enumerate(classes)}
idx_to_class = {idx: cls_name for cls_name, idx in class_to_idx.items()}

<h3>Pre-processing: cropping and lowering resolution, per paper.</h3>

In [49]:
def crop_center(img, crop_ratio):
    width, height = img.size
    new_size = int(crop_ratio * min(width, height))
    left = (width - new_size) // 2
    top = (height - new_size) // 2
    right = left + new_size
    bottom = top + new_size
    return img.crop((left, top, right, bottom))


transform = transforms.Compose([
    transforms.Lambda(lambda img: crop_center(img, 0.8)),
    transforms.Resize((224, 224)), # 224x224 is a common choice for RESNET-18, I'm told.. 
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], 
                         [0.229, 0.224, 0.225]),
])

<h3>Dataset Class</h3>

In [83]:
class PAD_UFES_Dataset(Dataset):
    def __init__(self, img_dir, label_dict, transform=None):
        self.img_dir = img_dir
        self.label_dict = label_dict
        self.transform = transform
        self.image_files = [f for f in os.listdir(img_dir) 
                            if f.endswith('.png') and f in label_dict]

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        img_name = self.image_files[idx]
        label_name = self.label_dict[img_name]
        label = class_to_idx[label_name]
        img_path = os.path.join(self.img_dir, img_name)
        image = Image.open(img_path).convert('RGB')
        if self.transform:
            image = self.transform(image)
        return image, label

<h3>Establish Paths</h3>

In [84]:
log_dir_path = '/Users/vedansh/Desktop/230PRJ/skin-lesion-classifier/runs'

data_path = os.path.expanduser('~/Desktop/230PRJ/PAD-UFES-20/')
metadata_path = os.path.join(data_path, 'metadata.csv')
images_path = os.path.join(data_path, 'images')

<h3>Load Labels, Initialize and Split dataset</h3>

In [85]:
metadata = pd.read_csv(metadata_path)

label_dict = dict(zip(metadata['img_id'], metadata['diagnostic']))
label_dict = {f"{key}": value for key, value in label_dict.items()}

dataset = PAD_UFES_Dataset(img_dir=images_path, label_dict=label_dict, transform=transform)

dataset_size = len(dataset)
train_size = int(0.8 * dataset_size)
val_size = int(0.1 * dataset_size)
test_size = dataset_size - train_size - val_size
train_dataset, val_dataset, test_dataset = random_split(dataset, [train_size, val_size, test_size])

# Sanity check
assert train_size + test_size + val_size == dataset_size == 2298

<h3>Get class weights for later: potentially weighted softmax, or re-sampling, etc.</h3>

In [88]:
train_labels = [label for _, label in train_dataset]
label_counts = Counter(train_labels)
total_samples = sum(label_counts.values())
class_weights = [total_samples / label_counts[i] for i in range(len(classes))]
class_weights = torch.FloatTensor(class_weights).to(DEVICE)

<h3>Set-up Data Loaders</h3>

In [98]:
batch_size = 32
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

<h3>Attempt to use RESNET-18</h3>

In [99]:
assert len(classes) == 6

model = models.resnet18(pretrained=True)
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, len(classes))

# Unfreeze just the last CONV layer and last FC layer, and tweak it for 6 outputs
for param in model.parameters():
    param.requires_grad = False
for param in model.layer4.parameters():
    param.requires_grad = True
for param in model.fc.parameters():
    param.requires_grad = True

model = model.to(DEVICE)

<h3>Loss function, optimizer</h3>

In [100]:
criterion = nn.CrossEntropyLoss(weight=class_weights) # Weighted CEL
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9) # Keeping this the same as the paper.
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1) # LR decay.

<h3>Training time</h3>

In [102]:
num_epochs = 5

for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    running_corrects = 0
    
    for inputs, labels in train_loader:
        inputs = inputs.to(DEVICE)
        labels = labels.to(DEVICE)
        
        optimizer.zero_grad()
        
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        
        loss.backward()
        optimizer.step()
        
        _, preds = torch.max(outputs, 1)
        running_loss += loss.item() * inputs.size(0)
        running_corrects += torch.sum(preds == labels.data)
    
    scheduler.step()
    
    epoch_loss = running_loss / train_size
    epoch_acc = running_corrects.double() / train_size
    
    # Validation
    model.eval()
    val_running_loss = 0.0
    val_running_corrects = 0
    
    with torch.no_grad():
        for val_inputs, val_labels in val_loader:
            val_inputs = val_inputs.to(DEVICE)
            val_labels = val_labels.to(DEVICE)
            
            val_outputs = model(val_inputs)
            val_loss = criterion(val_outputs, val_labels)
            
            _, val_preds = torch.max(val_outputs, 1)
            val_running_loss += val_loss.item() * val_inputs.size(0)
            val_running_corrects += torch.sum(val_preds == val_labels.data)
    
    val_loss = val_running_loss / val_size
    val_acc = val_running_corrects.double() / val_size
    
    print(f'Epoch {epoch+1}/{num_epochs}')
    print(f'Train Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')
    print(f'Val Loss: {val_loss:.4f} Acc: {val_acc:.4f}')

Epoch 1/5
Train Loss: 1.6622 Acc: 0.4314
Val Loss: 1.5077 Acc: 0.4061
Epoch 2/5
Train Loss: 0.9383 Acc: 0.6464
Val Loss: 1.1442 Acc: 0.4803
Epoch 3/5
Train Loss: 0.5977 Acc: 0.7519
Val Loss: 1.6003 Acc: 0.6681
Epoch 4/5
Train Loss: 0.3357 Acc: 0.8618
Val Loss: 1.7841 Acc: 0.6507
Epoch 5/5
Train Loss: 0.2370 Acc: 0.9108
Val Loss: 1.7126 Acc: 0.6070


<h3>Testing</h3>

In [104]:
model.eval()
test_running_corrects = 0

with torch.no_grad():
    for test_inputs, test_labels in test_loader:
        test_inputs = test_inputs.to(DEVICE)
        test_labels = test_labels.to(DEVICE)
        
        test_outputs = model(test_inputs)
        _, test_preds = torch.max(test_outputs, 1)
        test_running_corrects += torch.sum(test_preds == test_labels.data)

test_acc = test_running_corrects.double() / test_size
print(f'Test Accuracy: {test_acc:.4f}')

Test Accuracy: 0.5758


<h3>Save RESNET-18</h3>

In [None]:
#saveity save save

<h3>Custom (Simple) CNN Model</h3>

In [106]:
class SimpleCNN(nn.Module):
    def __init__(self, num_classes):
        super(SimpleCNN, self).__init__()
        # 3 CONV
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        # 1 MAXPOOL
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        # 2 FC
        self.fc1 = nn.Linear(128 * 28 * 28, 256)
        self.fc2 = nn.Linear(256, len(classes))
        # 1 DO
        self.dropout = nn.Dropout(0.5)
    
    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))  # OP Dims: [batch_size, 32, 112, 112]
        x = self.pool(F.relu(self.conv2(x)))  # OP Dims: [batch_size, 64, 56, 56]
        x = self.pool(F.relu(self.conv3(x)))  # OP Dims: [batch_size, 128, 28, 28]
        x = x.view(-1, 128 * 28 * 28)
        x = F.relu(self.fc1(x))
        x = self.dropout(x)
        x = self.fc2(x)
        return x

<h3>Set this up now.</h3>

In [None]:
model = SimpleCNN(num_classes=len(classes))
assert len(classes) == 6, "you're fucked."
model = model.to(DEVICE)

criterion = nn.CrossEntropyLoss(weight=class_weights.to(DEVICE))
optimizer = optim.Adam(model.parameters(), lr=0.001)

<h3>Train SIMPLE CNN</h3>

In [110]:
num_epochs = 5

for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    running_corrects = 0
    
    for inputs, labels in train_loader:
        inputs = inputs.to(DEVICE)
        labels = labels.to(DEVICE)
        
        optimizer.zero_grad()

        # Spring forward
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        # Fall back
        loss.backward()
        optimizer.step()
        
        _, preds = torch.max(outputs, 1)
        running_loss += loss.item() * inputs.size(0)
        running_corrects += torch.sum(preds == labels.data)
    
    epoch_loss = running_loss / train_size
    epoch_acc = running_corrects.double() / train_size
    
    # Validation
    model.eval()
    val_running_loss = 0.0
    val_running_corrects = 0
    
    with torch.no_grad():
        for val_inputs, val_labels in val_loader:
            val_inputs = val_inputs.to(DEVICE)
            val_labels = val_labels.to(DEVICE)
            
            val_outputs = model(val_inputs)
            val_loss = criterion(val_outputs, val_labels)
            
            _, val_preds = torch.max(val_outputs, 1)
            val_running_loss += val_loss.item() * val_inputs.size(0)
            val_running_corrects += torch.sum(val_preds == val_labels.data)
    
    val_loss = val_running_loss / val_size
    val_acc = val_running_corrects.double() / val_size
    
    print(f'Epoch {epoch+1}/{num_epochs}')
    print(f'Train Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')
    print(f'Val Loss: {val_loss:.4f} Acc: {val_acc:.4f}')

Epoch 1/5
Train Loss: 0.1054 Acc: 0.9532
Val Loss: 1.9611 Acc: 0.6725
Epoch 2/5
Train Loss: 0.0321 Acc: 0.9880
Val Loss: 1.7768 Acc: 0.6681
Epoch 3/5
Train Loss: 0.0454 Acc: 0.9793
Val Loss: 1.9663 Acc: 0.6070
Epoch 4/5
Train Loss: 0.0524 Acc: 0.9761
Val Loss: 2.0009 Acc: 0.6900
Epoch 5/5
Train Loss: 0.0621 Acc: 0.9739
Val Loss: 1.9509 Acc: 0.6856


<h3>Test SIMPLE CNN</h3>

In [112]:
model.eval()
test_running_corrects = 0

with torch.no_grad():
    for test_inputs, test_labels in test_loader:
        test_inputs = test_inputs.to(DEVICE)
        test_labels = test_labels.to(DEVICE)
        
        test_outputs = model(test_inputs)
        _, test_preds = torch.max(test_outputs, 1)
        test_running_corrects += torch.sum(test_preds == test_labels.data)

test_acc = test_running_corrects.double() / test_size
print(f'Test Accuracy: {test_acc:.4f}')

Test Accuracy: 0.6450


<h3>Save SIMPLE CNN</h3>

In [None]:
# Save l8r