# In [2]:
import cv2
import numpy as np
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
import matplotlib.pyplot as plt
import torchvision.transforms as transforms
import random
from sklearn.metrics import classification_report, confusion_matrix
import glob
import seaborn as sns
from torchvision import models
from torch.optim.lr_scheduler import ReduceLROnPlateau
from collections import Counter

class LungCancerDataset(Dataset):
    """Image dataset laid out on disk as ``<root_dir>/<class name>/*.png``.

    Each item is an ``(image, label)`` pair. ``label`` is the integer that
    ``self.labels`` assigns to the class directory the file was found in and
    ``image`` is the output of ``self.transform`` applied to the RGB pixel
    array read from the file.

    Args:
        root_dir: Directory that contains one sub-directory per class name.
        transform: Callable applied to every loaded RGB image array. When
            ``None`` (the default) the augmenting pipeline built by
            ``_default_transform`` is used, which is what this class applied
            unconditionally before the argument was honoured.
    """

    def __init__(self, root_dir, transform=None):
        # Respect a caller-supplied pipeline; only fall back to the built-in
        # augmenting one when nothing was passed.
        if transform is None:
            transform = self._default_transform()
        self.transform = transform

        self.labels = {"adenocarcinoma": 0, "large.cell.carcinoma": 1, "normal": 2, "squamous.cell.carcinoma": 3}
        self.filepaths = []
        self.targets = []
        for label, idx in self.labels.items():
            # glob returns files in arbitrary order; sort so that a given
            # index always refers to the same file.
            paths = sorted(glob.glob(f"{root_dir}/{label}/*.png"))
            self.filepaths.extend(paths)
            self.targets.extend([idx] * len(paths))

    @staticmethod
    def _default_transform():
        """Build the default pipeline: resize to 224x224, random flip,
        rotation, colour jitter and translation, then tensor conversion and
        per-channel normalisation with mean 0.5 and std 0.5."""
        return transforms.Compose([
            transforms.ToPILImage(),
            transforms.Resize((224, 224)),
            transforms.RandomHorizontalFlip(),
            transforms.RandomRotation(10),
            transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.2),
            transforms.RandomAffine(degrees=0, translate=(0.1, 0.1)),
            transforms.ToTensor(),
            transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
        ])

    def __len__(self):
        """Return the number of image files discovered under ``root_dir``."""
        return len(self.filepaths)

    def __getitem__(self, index):
        """Load, colour-convert and transform the image at ``index``.

        Raises:
            OSError: If the file cannot be read as an image.
        """
        filepath = self.filepaths[index]
        img = cv2.imread(filepath)
        if img is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise OSError(f"Could not read image file: {filepath}")
        # OpenCV loads pixels in BGR order; the transform pipeline expects RGB.
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        if self.transform is not None:
            img = self.transform(img)
        label = self.targets[index]
        return img, label

class ModifiedResNet50(nn.Module):
    """Pretrained torchvision ResNet-50 with a replaced classification head.

    The backbone's final fully connected layer is swapped for
    ``Dropout(0.5) -> Linear(in_features, num_classes)``, and the pooled
    features pass through an extra ``BatchNorm1d`` before reaching that head.

    Args:
        num_classes: Number of output logits.
    """

    def __init__(self, num_classes=4):
        super().__init__()
        backbone = models.resnet50(pretrained=True)
        feature_dim = backbone.fc.in_features
        backbone.fc = nn.Sequential(
            nn.Dropout(0.5),
            nn.Linear(feature_dim, num_classes),
        )
        self.model = backbone
        self.batch_norm = nn.BatchNorm1d(feature_dim)

    def forward(self, x):
        """Return class logits for a batch of images ``x``."""
        backbone = self.model
        # The backbone is run stage by stage (rather than via backbone(x)) so
        # that batch_norm can be inserted between pooling and the fc head.
        feature_stages = (
            backbone.conv1,
            backbone.bn1,
            backbone.relu,
            backbone.maxpool,
            backbone.layer1,
            backbone.layer2,
            backbone.layer3,
            backbone.layer4,
            backbone.avgpool,
        )
        for stage in feature_stages:
            x = stage(x)
        pooled = torch.flatten(x, 1)
        normalised = self.batch_norm(pooled)
        return backbone.fc(normalised)

# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model = ModifiedResNet50()
model = model.to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(
    model.parameters(),
    lr=1e-4,
    weight_decay=1e-5,
)
# Multiplies the learning rate by `factor` once the value handed to
# scheduler.step() has stopped decreasing for `patience` epochs.
scheduler = ReduceLROnPlateau(
    optimizer,
    mode='min',
    factor=0.1,
    patience=5,
    verbose=True,
)

def train_model(dataloader, model, criterion, optimizer, num_epochs):
    """Train ``model`` and keep the weights with the best validation accuracy.

    After every epoch the model is evaluated with ``validate_model`` and the
    module-level ``scheduler`` is stepped with the validation loss. When
    training ends, the best-scoring weights are loaded back into ``model`` and
    written to ``best_model.pth``. Batches are moved to the module-level
    ``device``.

    Args:
        dataloader: Mapping with ``'train'`` and ``'valid'`` data loaders that
            yield ``(inputs, labels)`` batches.
        model: Network to optimise (modified in place).
        criterion: Loss function called as ``criterion(outputs, labels)``.
            The per-sample average below assumes it returns the batch mean.
        optimizer: Optimiser that updates ``model``'s parameters.
        num_epochs: Number of passes over the training loader.

    Returns:
        Tuple ``(train_loss_history, valid_loss_history, valid_acc_history)``
        with one entry per epoch in each list.
    """
    train_loss_history = []
    valid_loss_history = []
    valid_acc_history = []

    best_model_wts = None
    best_acc = 0.0

    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0
        samples_seen = 0

        for inputs, labels in dataloader['train']:
            inputs = inputs.to(device)
            labels = labels.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            batch_size = inputs.size(0)
            running_loss += loss.item() * batch_size
            samples_seen += batch_size

        # Average over the samples actually processed so the figure stays
        # correct when the loader uses a sampler or drops incomplete batches.
        epoch_loss = running_loss / samples_seen
        train_loss_history.append(epoch_loss)

        val_loss, val_acc = validate_model(dataloader['valid'], model, device, criterion)
        valid_loss_history.append(val_loss)
        valid_acc_history.append(val_acc)

        print(f'Epoch {epoch}/{num_epochs - 1}, Train Loss: {epoch_loss}, Validation Loss: {val_loss}, Validation Accuracy: {val_acc}')
        scheduler.step(val_loss)

        if val_acc > best_acc:
            best_acc = val_acc
            # state_dict() hands back references to the live tensors, and a
            # dict .copy() is shallow, so every tensor must be cloned or the
            # snapshot would keep changing as training continues.
            best_model_wts = {
                name: tensor.detach().clone()
                for name, tensor in model.state_dict().items()
            }

    print(f'Best Validation Accuracy: {best_acc}')
    # No snapshot exists when num_epochs is 0 or accuracy never rose above
    # 0.0; in that case the current weights are kept and saved as they are.
    if best_model_wts is not None:
        model.load_state_dict(best_model_wts)
    torch.save(model.state_dict(), 'best_model.pth')

    return train_loss_history, valid_loss_history, valid_acc_history

def validate_model(dataloader, model, device, criterion):
    """Evaluate ``model`` on ``dataloader`` without updating any weights.

    The model is switched to evaluation mode and left in that mode.

    Args:
        dataloader: Iterable of ``(inputs, labels)`` batches.
        model: Network to evaluate.
        device: Device that every batch is moved to before the forward pass.
        criterion: Loss function called as ``criterion(outputs, labels)``.
            The per-sample average below assumes it returns the batch mean.

    Returns:
        Tuple ``(loss, accuracy)``: the mean loss per sample and the fraction
        of samples whose highest-scoring class equals the label.

    Raises:
        ZeroDivisionError: If ``dataloader`` yields no samples.
    """
    model.eval()
    running_loss = 0.0
    correct = 0
    total = 0

    with torch.no_grad():
        for inputs, labels in dataloader:
            inputs = inputs.to(device)
            labels = labels.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            batch_size = labels.size(0)
            # Undo the batch mean so batches of different size are weighted
            # by how many samples they contain.
            running_loss += loss.item() * batch_size
            preds = outputs.argmax(dim=1)
            correct += (preds == labels).sum().item()
            total += batch_size

    # Both metrics are normalised by the number of samples actually seen;
    # len(dataloader.dataset) would be wrong when the loader draws only a
    # subset of the dataset through a sampler.
    loss = running_loss / total
    accuracy = correct / total

    return loss, accuracy

def plot_metrics(train_loss, valid_loss, valid_acc):
    """Show loss and accuracy curves side by side in one figure.

    Args:
        train_loss: Training loss per epoch.
        valid_loss: Validation loss per epoch.
        valid_acc: Validation accuracy per epoch.
    """
    epoch_axis = range(len(train_loss))

    # One entry per panel: (y-axis label, title, curves drawn in that panel).
    panels = (
        ('Loss', 'Loss over Epochs',
         (('Train Loss', train_loss), ('Validation Loss', valid_loss))),
        ('Accuracy', 'Accuracy over Epochs',
         (('Validation Accuracy', valid_acc),)),
    )

    plt.figure(figsize=(12, 4))

    for position, (y_label, title, curves) in enumerate(panels, start=1):
        plt.subplot(1, 2, position)
        for legend_name, values in curves:
            plt.plot(epoch_axis, values, label=legend_name)
        plt.xlabel('Epochs')
        plt.ylabel(y_label)
        plt.legend()
        plt.title(title)

    plt.tight_layout()
    plt.show()

# Deterministic preprocessing for validation data: the same resize, tensor
# conversion and normalisation as the dataset's default pipeline, but without
# its random flip / rotation / colour jitter / translation steps.
eval_transform = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
])

train_dataset = LungCancerDataset(root_dir='./data/train')
valid_dataset = LungCancerDataset(root_dir='./data/valid')
# Random augmentation makes the validation loss and accuracy vary from run to
# run, and those values drive the LR schedule and best-model selection. The
# pipeline is assigned on the instance because __getitem__ reads
# self.transform for every sample.
valid_dataset.transform = eval_transform

train_dataloader = DataLoader(train_dataset, batch_size=64, shuffle=True)
valid_dataloader = DataLoader(valid_dataset, batch_size=64, shuffle=False)

dataloader = {'train': train_dataloader, 'valid': valid_dataloader}

num_epochs = 100
train_loss, valid_loss, valid_acc = train_model(dataloader, model, criterion, optimizer, num_epochs)

plot_metrics(train_loss, valid_loss, valid_acc)

# The held-out test set must be evaluated on unmodified images. Replace the
# dataset's default (randomly augmenting) pipeline with a deterministic one
# so that reported test metrics are reproducible.
test_transform = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
])
test_dataset = LungCancerDataset(root_dir='./data/test')
# Assigned on the instance because __getitem__ reads self.transform per sample.
test_dataset.transform = test_transform
test_dataloader = DataLoader(test_dataset, batch_size=64, shuffle=False)

def test_model(test_dataloader, model, device):
    """Run ``model`` over ``test_dataloader`` and print its accuracy.

    Args:
        test_dataloader: Iterable of ``(images, labels)`` batches.
        model: Network to evaluate; it is switched to evaluation mode.
        device: Device that every batch is moved to before the forward pass.

    Returns:
        Tuple ``(all_labels, all_predictions)``: two lists holding the true
        label and the predicted class of every sample, in iteration order.
    """
    model.eval()
    all_labels = []
    all_predictions = []
    n_correct = 0
    n_seen = 0

    with torch.no_grad():
        for images, labels in test_dataloader:
            images, labels = images.to(device), labels.to(device)
            logits = model(images)
            # Index of the largest logit in each row is the predicted class.
            predicted = torch.max(logits.data, 1)[1]
            n_seen += labels.size(0)
            n_correct += (predicted == labels).sum().item()
            all_predictions.extend(predicted.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

    accuracy = n_correct / n_seen
    print(f'Accuracy of the model on the test images: {100 * accuracy}%')
    return all_labels, all_predictions

test_labels, test_predictions = test_model(test_dataloader, model, device)

# Class names listed so that position i is the name of integer label i.
class_names = ['adenocarcinoma', 'large.cell.carcinoma', 'normal', 'squamous.cell.carcinoma']
class_indices = list(range(len(class_names)))

# Pass the full label set explicitly. Without it scikit-learn only uses the
# labels that actually occur, so a class missing from the test results would
# shrink the matrix, misalign the tick labels and make classification_report
# reject target_names for having the wrong length.
cm = confusion_matrix(test_labels, test_predictions, labels=class_indices)
plt.figure(figsize=(16, 9))
ax = plt.subplot()
sns.heatmap(cm, annot=True, fmt='g', ax=ax)
ax.set_xlabel('Predicted labels')
ax.set_ylabel('True labels')
ax.set_title('Confusion Matrix')
ax.xaxis.set_ticklabels(class_names)
ax.yaxis.set_ticklabels(class_names)
plt.show()

print(classification_report(test_labels, test_predictions, labels=class_indices, target_names=class_names))
