In [1]:
import torch
import torch.nn as nn
from einops import rearrange
import time

In [2]:
# Mount Google Drive at /content/drive; the dataset directory read later in this
# notebook lives under /content/drive/MyDrive.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [3]:

pip install torchviz


Collecting torchviz
  Downloading torchviz-0.0.3-py3-none-any.whl.metadata (2.1 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch->torchviz)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch->torchviz)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch->torchviz)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch->torchviz)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.4.5.8 (from torch->torchviz)
  Downloading nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cufft-cu12==11.2.1.3 (from torch->torchviz)
  Downloading nvidia_cufft_cu12-11.2.1.3-py

In [4]:
pip install python-docx

Collecting python-docx
  Downloading python_docx-1.2.0-py3-none-any.whl.metadata (2.0 kB)
Downloading python_docx-1.2.0-py3-none-any.whl (252 kB)
   ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 253.0/253.0 kB 8.0 MB/s eta 0:00:00
Installing collected packages: python-docx
Successfully installed python-docx-1.2.0


In [3]:

import os
import pandas as pd
from PIL import Image
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torchvision.models import resnet18, ResNet18_Weights
from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix, classification_report
import matplotlib.pyplot as plt
import seaborn as sns
from docx import Document
from docx.shared import Inches
import warnings

# Optional dependency: torchviz is only needed to draw the model graph.
# TORCHVIZ_AVAILABLE records whether the import worked so later code can skip
# visualization instead of failing.
try:
    from torchviz import make_dot
except ImportError:
    TORCHVIZ_AVAILABLE = False
    warnings.warn("torchviz not installed. Model visualization will be skipped.")
else:
    TORCHVIZ_AVAILABLE = True

# Config
IMG_SIZE = 112      # images are resized to IMG_SIZE x IMG_SIZE before entering the network
BATCH_SIZE = 16
EPOCHS = 10
SEED = 42           # shared by the data split and torch's global RNG

# Seed torch so weight initialisation of the newly created layers, DataLoader
# shuffling and random augmentations repeat across runs (GPU kernels may still
# introduce small non-deterministic differences).
torch.manual_seed(SEED)

# STEP 1: Load dataset paths and labels
image_dir = '/content/drive/MyDrive/PhDProject/colon_image_sets'
label_map = {'colon_n': 0, 'colon_aca': 1}   # sub-folder name -> integer class label
IMAGE_EXTENSIONS = ('.jpg', '.png', '.jpeg')

paths, labels = [], []
for folder, label in label_map.items():
    class_dir = os.path.join(image_dir, folder)
    # os.listdir returns entries in arbitrary order; sorting makes the row order,
    # and therefore the seeded split below, reproducible between runs.
    for img in sorted(os.listdir(class_dir)):
        if img.lower().endswith(IMAGE_EXTENSIONS):
            paths.append(os.path.join(class_dir, img))
            labels.append(label)

if not paths:
    # Fail here with a clear message rather than inside train_test_split.
    raise FileNotFoundError(f"No image files with extensions {IMAGE_EXTENSIONS} found under {image_dir}")

df = pd.DataFrame({'filepaths': paths, 'labels': labels})

# Stratified split: 50% train, then the remaining half is divided evenly into
# test and validation (25% each of the full dataset).
train_df, dummy_df = train_test_split(df, train_size=0.5, stratify=df['labels'], random_state=SEED)
test_df, valid_df = train_test_split(dummy_df, train_size=0.5, stratify=dummy_df['labels'], random_state=SEED)

# Sanity check: the three splits partition the dataset.
assert len(train_df) + len(test_df) + len(valid_df) == len(df)

# STEP 2: Dataset class
class ColonDataset(Dataset):
    """Dataset that serves ``(image, label)`` pairs described by a DataFrame.

    The frame is read through two columns: ``filepaths`` (location of each image
    file) and ``labels`` (the value returned alongside the image).

    Args:
        df: DataFrame with ``filepaths`` and ``labels`` columns, one row per sample.
        transform: Optional callable applied to the RGB PIL image before it is returned.
    """

    def __init__(self, df, transform=None):
        self.transform = transform
        # Renumber rows 0..n-1 so an integer sample index is a valid .loc label
        # even when the incoming frame is a shuffled subset of a larger one.
        self.df = df.reset_index(drop=True)

    def __len__(self):
        """Return the number of samples (rows in the frame)."""
        return len(self.df)

    def __getitem__(self, idx):
        """Load sample ``idx`` and return ``(image, label)``."""
        sample_path = self.df.loc[idx, 'filepaths']
        target = self.df.loc[idx, 'labels']
        # Convert to RGB so every image has three channels regardless of file mode.
        picture = Image.open(sample_path).convert('RGB')
        if not self.transform:
            return picture, target
        return self.transform(picture), target

# STEP 3: Transforms
# Training pipeline: resize, random horizontal flip (augmentation), tensor
# conversion, then normalisation with the ImageNet channel statistics that
# torchvision's pretrained models expect.
train_transform = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

# Evaluation pipeline: identical preprocessing but WITHOUT the random flip, so
# validation/test metrics are deterministic and measured on un-augmented images.
eval_transform = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

# Backward-compatible alias: `transform` still names the training pipeline.
transform = train_transform

# STEP 4: DataLoaders
# Only the training loader shuffles; evaluation order does not matter.
train_loader = DataLoader(ColonDataset(train_df, train_transform), batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(ColonDataset(valid_df, eval_transform), batch_size=BATCH_SIZE)
test_loader = DataLoader(ColonDataset(test_df, eval_transform), batch_size=BATCH_SIZE)

# STEP 5: MTANv3 with ResNet18
class MTANv3_ResNet18(nn.Module):
    """Pretrained ResNet18 trunk with a feature-gating block and a one-logit head.

    Pipeline: ``features`` (all ResNet18 children except the last) produces a
    tensor that is flattened to one 512-value vector per sample; ``attention``
    maps that vector to 512 sigmoid gates in (0, 1); the vector is multiplied
    element-wise by the gates and passed to ``classifier``, which emits a single
    raw logit per sample (no sigmoid is applied here).
    """

    def __init__(self):
        super().__init__()
        backbone = resnet18(weights=ResNet18_Weights.DEFAULT)
        # Drop the final child module of the backbone and keep the rest as the
        # feature extractor.
        trunk_layers = list(backbone.children())[:-1]
        self.features = nn.Sequential(*trunk_layers)
        # Bottleneck MLP (512 -> 256 -> 512) ending in a sigmoid: one gate per feature.
        self.attention = nn.Sequential(
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Linear(256, 512),
            nn.Sigmoid(),
        )
        self.classifier = nn.Linear(512, 1)

    def forward(self, x):
        """Return raw logits of shape ``(batch, 1)`` for a batch of images ``x``."""
        feature_map = self.features(x)
        # Collapse everything after the batch dimension into one vector per sample.
        flat = feature_map.view(feature_map.size(0), -1)
        gates = self.attention(flat)
        gated = flat * gates
        return self.classifier(gated)

# STEP 6: Setup
# Run on the GPU when torch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model = MTANv3_ResNet18()
model = model.to(device)
# The model emits raw logits, so use the loss that takes logits directly.
criterion = nn.BCEWithLogitsLoss()
# Every parameter (backbone included) is handed to the optimizer.
optimizer = optim.Adam(params=model.parameters(), lr=1e-4)

# STEP 7: Training loop
# One entry per epoch for each metric; saved to CSV and plotted after training.
history = {'train_loss': [], 'val_loss': [], 'train_acc': [], 'val_acc': []}

for epoch in range(EPOCHS):
    # ---- Training pass ----
    model.train()
    running_loss, correct_train, total_train = 0.0, 0, 0
    # The batch variable is named `targets` (not `labels`) so the module-level
    # `labels` list built while scanning the dataset is not overwritten.
    for images, targets in train_loader:
        # BCEWithLogitsLoss needs float targets on the same device as the logits.
        images, targets = images.to(device), targets.float().to(device)
        optimizer.zero_grad()
        # squeeze(1) removes only the logit dimension: (B, 1) -> (B,).
        # A bare squeeze() would also drop the batch dimension when the final
        # batch holds a single sample, producing a 0-d tensor whose shape no
        # longer matches the (1,) target and making the loss raise ValueError.
        outputs = model(images).squeeze(1)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        # Threshold the predicted probability at 0.5 to get hard 0/1 predictions.
        preds = torch.round(torch.sigmoid(outputs))
        correct_train += (preds == targets).sum().item()
        total_train += targets.size(0)

    # Loss is the mean of per-batch mean losses; accuracy is per sample.
    avg_train_loss = running_loss / len(train_loader)
    train_acc = correct_train / total_train
    history['train_loss'].append(avg_train_loss)
    history['train_acc'].append(train_acc)

    # ---- Validation pass (no gradient tracking, eval-mode layers) ----
    model.eval()
    val_loss, correct, total = 0.0, 0, 0
    # Reset every epoch: after the loop these hold the LAST epoch's validation
    # predictions and targets.
    all_preds, all_labels = [], []
    with torch.no_grad():
        for images, targets in val_loader:
            images, targets = images.to(device), targets.float().to(device)
            outputs = model(images).squeeze(1)  # keep the batch dimension (see above)
            loss = criterion(outputs, targets)
            val_loss += loss.item()
            preds = torch.round(torch.sigmoid(outputs))
            correct += (preds == targets).sum().item()
            total += targets.size(0)
            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(targets.cpu().numpy())

    avg_val_loss = val_loss / len(val_loader)
    val_acc = correct / total
    history['val_loss'].append(avg_val_loss)
    history['val_acc'].append(val_acc)

    print(f"Epoch {epoch+1}/{EPOCHS} - Train Loss: {avg_train_loss:.4f}, Acc: {train_acc:.4f} | Val Loss: {avg_val_loss:.4f}, Acc: {val_acc:.4f}")

# Save model weights and the per-epoch training history
torch.save(model.state_dict(), 'mtanv3_resnet18_colon.pth')
pd.DataFrame(history).to_csv('history_resnet18.csv', index=False)

# Final evaluation on the held-out test split.
# The validation set was monitored during training, so the reported confusion
# matrix and classification report are computed on `test_loader`, which the
# model has not been evaluated on until now.
model.eval()
test_preds, test_targets = [], []
with torch.no_grad():
    for images, targets in test_loader:
        # squeeze(1) keeps the batch dimension even for a single-sample batch.
        logits = model(images.to(device)).squeeze(1)
        batch_preds = torch.round(torch.sigmoid(logits))
        # Store plain Python ints so sklearn sees integer class ids 0/1.
        test_preds.extend(batch_preds.long().cpu().tolist())
        test_targets.extend(targets.long().tolist())

# Confusion Matrix & Classification Report
# labels=[0, 1] fixes the class order and guarantees a 2x2 matrix (and a valid
# target_names mapping) even if one class is absent from the predictions.
cm = confusion_matrix(test_targets, test_preds, labels=[0, 1])
report = classification_report(test_targets, test_preds, labels=[0, 1], target_names=['Normal', 'Adenocarcinoma'])
pd.DataFrame(cm).to_csv('confusion_matrix_resnet18.csv')
with open('classification_report_resnet18.txt', 'w') as f:
    f.write(report)

# STEP 8: Plots
# The loss and accuracy figures share one layout, so both are drawn from a small
# table of (train key, train legend, val key, val legend, title, y-label, file).
curve_specs = (
    ('train_loss', 'Train Loss', 'val_loss', 'Val Loss', 'Loss Curve', 'Loss', 'loss_curve_resnet18.png'),
    ('train_acc', 'Train Acc', 'val_acc', 'Val Acc', 'Accuracy Curve', 'Accuracy', 'accuracy_curve_resnet18.png'),
)
for train_key, train_legend, val_key, val_legend, curve_title, y_label, curve_file in curve_specs:
    plt.figure()
    plt.plot(history[train_key], label=train_legend)
    plt.plot(history[val_key], label=val_legend)
    plt.title(curve_title)
    plt.xlabel('Epoch')
    plt.ylabel(y_label)
    plt.legend()
    plt.grid(True)
    plt.savefig(curve_file)
    # Close each figure after saving so it is neither shown inline nor kept in memory.
    plt.close()

# Confusion-matrix heatmap; rows are actual classes, columns are predicted classes.
class_names = ['Normal', 'Adenocarcinoma']
plt.figure(figsize=(6,6))
sns.heatmap(
    cm,
    annot=True,
    fmt='d',
    cmap='Blues',
    xticklabels=class_names,
    yticklabels=class_names,
)
plt.title('Confusion Matrix')
plt.xlabel('Predicted')
plt.ylabel('Actual')
plt.savefig('confusion_matrix_resnet18.png')
plt.close()

# STEP 9: Word Report
# Assemble a .docx with the text classification report followed by the three
# figures saved to disk in the plotting step.
doc = Document()
doc.add_heading('Colon Cancer Classification Report (ResNet18 MTANv3)', 0)
doc.add_paragraph(report)

# (image file, display width in inches), inserted in this order.
report_figures = (
    ('loss_curve_resnet18.png', 5),
    ('accuracy_curve_resnet18.png', 5),
    ('confusion_matrix_resnet18.png', 4.5),
)
for figure_file, width_in in report_figures:
    doc.add_picture(figure_file, width=Inches(width_in))

doc.save('ResNet18_Colon_Report.docx')

print("✅ All done! Report saved as 'ResNet18_Colon_Report.docx'")


Epoch 1/10 - Train Loss: 0.1092, Acc: 0.9655 | Val Loss: 0.0103, Acc: 0.9993
Epoch 2/10 - Train Loss: 0.0169, Acc: 0.9975 | Val Loss: 0.0168, Acc: 0.9964
Epoch 3/10 - Train Loss: 0.0089, Acc: 0.9982 | Val Loss: 0.0020, Acc: 1.0000
Epoch 4/10 - Train Loss: 0.0029, Acc: 0.9996 | Val Loss: 0.0007, Acc: 1.0000
Epoch 5/10 - Train Loss: 0.0103, Acc: 0.9960 | Val Loss: 0.0086, Acc: 0.9964
Epoch 6/10 - Train Loss: 0.0168, Acc: 0.9946 | Val Loss: 0.0117, Acc: 0.9949
Epoch 7/10 - Train Loss: 0.0022, Acc: 0.9993 | Val Loss: 0.0005, Acc: 1.0000
Epoch 8/10 - Train Loss: 0.0006, Acc: 1.0000 | Val Loss: 0.0005, Acc: 1.0000
Epoch 9/10 - Train Loss: 0.0035, Acc: 0.9989 | Val Loss: 0.0003, Acc: 1.0000
Epoch 10/10 - Train Loss: 0.0062, Acc: 0.9978 | Val Loss: 0.0070, Acc: 0.9956
✅ All done! Report saved as 'ResNet18_Colon_Report.docx'
