## Zip the `utils` folder from https://github.com/SHI-Labs/Compact-Transformers/tree/main/src and upload it to the notebook's working directory as `utils.zip`



In [None]:
!unzip utils
!pip install timm

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.metrics import precision_score, recall_score, f1_score, accuracy_score, confusion_matrix
from sklearn.utils.class_weight import compute_class_weight
from torch.utils.data import DataLoader, Subset, random_split
from torchvision import transforms
from torchvision.datasets import ImageFolder

from utils.tokenizer import Tokenizer
from utils.transformers import TransformerClassifier

# Define transformations
# Training-time augmentation (rotation, colour jitter, perspective, grayscale).
# Applied to the TRAINING split only.
train_transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.RandomRotation(degrees=15),
    transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1),
    transforms.RandomPerspective(distortion_scale=0.05, p=0.5),
    transforms.RandomGrayscale(p=0.1),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Deterministic preprocessing for evaluation: resize + normalize only.
val_transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Load the dataset
dataset_path = '/content/drive/MyDrive/Colonoscopy Images 3' # Make sure you update this path to your own personal path

# Two views of the same folder: one augmented (training) and one deterministic
# (validation). Splitting a single ImageFolder built with train_transform would
# also feed augmented images to validation and distort the validation metrics.
full_dataset = ImageFolder(dataset_path, transform=train_transform)
val_view = ImageFolder(dataset_path, transform=val_transform)
# The split below indexes both views with the same indices, so they must list
# exactly the same (path, label) samples in the same order.
if val_view.samples != full_dataset.samples:
    raise RuntimeError("Train and validation views of the dataset folder do not match; "
                       "did the folder change while loading?")

# Split the dataset (80/20) using one shared permutation of sample indices.
train_size = int(0.8 * len(full_dataset))
val_size = len(full_dataset) - train_size
split_seed = 42  # fixed so the train/val partition is reproducible across runs
permutation = torch.randperm(len(full_dataset),
                             generator=torch.Generator().manual_seed(split_seed)).tolist()
train_dataset = Subset(full_dataset, permutation[:train_size])
val_dataset = Subset(val_view, permutation[train_size:])

# Create dataloaders. Pinned host memory only helps host->GPU copies, so enable
# it only when CUDA is available.
pin_memory = torch.cuda.is_available()
train_loader = DataLoader(train_dataset, batch_size=2, shuffle=True, num_workers=2, pin_memory=pin_memory)
val_loader = DataLoader(val_dataset, batch_size=2, shuffle=False, num_workers=2, pin_memory=pin_memory)


# Define the CCT model
class CCT(nn.Module):
    """Compact Convolutional Transformer: convolutional tokenizer + transformer classifier.

    An input image is turned into a token sequence by ``Tokenizer`` and then
    classified by ``TransformerClassifier`` with sequence pooling enabled.

    Args:
        img_size: side length of the square input image; only used to compute
            how many tokens the tokenizer produces.
        embedding_dim: tokenizer output channels / transformer embedding width.
        n_input_channels: number of channels in the input image.
        n_conv_layers, kernel_size, stride, padding: tokenizer convolution settings.
        pooling_kernel_size, pooling_stride, pooling_padding: tokenizer max-pool settings.
        dropout, attention_dropout, stochastic_depth: regularisation rates passed
            to the classifier.
        num_layers, num_heads, mlp_ratio: transformer depth, head count and MLP ratio.
        num_classes: number of output classes (2 by default for this task).
        positional_embedding: positional-embedding mode passed to the classifier.
        *args, **kwargs: accepted and ignored.
    """

    def __init__(self, img_size=256, embedding_dim=768, n_input_channels=3, n_conv_layers=1,
                 kernel_size=7, stride=2, padding=3, pooling_kernel_size=3, pooling_stride=2,
                 pooling_padding=1, dropout=0., attention_dropout=0.1, stochastic_depth=0.1,
                 num_layers=7, num_heads=6, mlp_ratio=4.0, num_classes=2,
                 positional_embedding='learnable', *args, **kwargs):
        super().__init__()

        # Convolutional tokenizer: conv stack (no bias, ReLU) followed by max-pooling.
        tokenizer_options = dict(
            n_input_channels=n_input_channels,
            n_output_channels=embedding_dim,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            pooling_kernel_size=pooling_kernel_size,
            pooling_stride=pooling_stride,
            pooling_padding=pooling_padding,
            max_pool=True,
            activation=nn.ReLU,
            n_conv_layers=n_conv_layers,
            conv_bias=False,
        )
        self.tokenizer = Tokenizer(**tokenizer_options)

        # Token count the tokenizer emits for an img_size x img_size image; the
        # classifier is configured with this fixed sequence length.
        token_count = self.tokenizer.sequence_length(n_channels=n_input_channels,
                                                     height=img_size,
                                                     width=img_size)

        self.classifier = TransformerClassifier(
            sequence_length=token_count,
            embedding_dim=embedding_dim,
            seq_pool=True,
            dropout=dropout,
            attention_dropout=attention_dropout,
            stochastic_depth=stochastic_depth,
            num_layers=num_layers,
            num_heads=num_heads,
            mlp_ratio=mlp_ratio,
            num_classes=num_classes,
            positional_embedding=positional_embedding,
        )

    def forward(self, x):
        """Tokenize the image batch ``x`` and return the classifier's output."""
        tokens = self.tokenizer(x)
        return self.classifier(tokens)

def cct_14_for_256x256(pretrained=False, progress=False, img_size=256, positional_embedding='learnable', num_classes=2, *args, **kwargs):
    """Build the small CCT configuration used in this notebook for 256x256 inputs.

    Despite the ``cct_14`` name, the preset uses 7 transformer layers, 4 heads,
    an MLP ratio of 2 and a 128-dim embedding, with a 3x3 / stride-2 / padding-1
    tokenizer convolution. These values may need tuning for the task, dataset
    and resource constraints.

    Args:
        pretrained: must be False. No pretrained weights are loaded here, so
            requesting them raises instead of silently returning random weights.
        progress: accepted for signature compatibility; unused.
        img_size, positional_embedding, num_classes: passed to ``CCT``.
        *args: not supported (see Raises).
        **kwargs: passed to ``CCT``. These may override any preset above,
            e.g. ``num_layers=14``.

    Returns:
        A freshly initialised ``CCT`` model.

    Raises:
        NotImplementedError: if ``pretrained`` is truthy.
        TypeError: if extra positional arguments are supplied.
    """
    if pretrained:
        raise NotImplementedError("No pretrained weights are available for cct_14_for_256x256.")
    if args:
        # CCT's first positional parameter is img_size, which is always passed by
        # keyword below, so any extra positional argument would collide with it.
        raise TypeError("cct_14_for_256x256() does not accept extra positional arguments; "
                        "pass CCT options by keyword.")

    # Preset hyper-parameters for 256x256 images.
    config = {
        'num_layers': 7,
        'num_heads': 4,
        'mlp_ratio': 2,
        'embedding_dim': 128,
        'kernel_size': 3,
        'stride': 2,
        'padding': 1,
    }
    config.update(kwargs)  # caller-supplied keywords take precedence over presets
    return CCT(img_size=img_size, positional_embedding=positional_embedding,
               num_classes=num_classes, **config)


# Instantiate the model through the factory so the architecture's
# hyper-parameters are defined in exactly one place.
model = cct_14_for_256x256(img_size=256, positional_embedding='learnable', num_classes=2)

# Set device for training
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

# Define loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
# Learning-rate scheduler: StepLR multiplies the LR by gamma once every
# step_size calls to scheduler.step(). It is meant to be stepped once per EPOCH,
# giving a 10x LR drop every 10 epochs.
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)

In [None]:
# Training loop
num_epochs = 20
best_val_loss = float('inf')
early_stopping_counter = 0
early_stopping_limit = 5
best_model_state = None  # snapshot of the weights with the lowest validation loss so far


def _train_one_epoch(net, loader, loss_fn, opt, dev):
    """Run one optimisation pass over ``loader`` and return the mean batch loss."""
    net.train()
    running_loss = 0.0
    for inputs, labels in loader:
        inputs, labels = inputs.to(dev), labels.to(dev)
        opt.zero_grad()
        loss = loss_fn(net(inputs), labels)
        loss.backward()
        opt.step()
        running_loss += loss.item()
    return running_loss / max(len(loader), 1)


def _evaluate(net, loader, loss_fn, dev):
    """Evaluate without gradients; return (mean batch loss, predictions, targets)."""
    net.eval()
    running_loss = 0.0
    predictions, targets = [], []
    with torch.no_grad():
        for inputs, labels in loader:
            inputs, labels = inputs.to(dev), labels.to(dev)
            outputs = net(inputs)
            running_loss += loss_fn(outputs, labels).item()
            predictions.extend(outputs.argmax(dim=1).tolist())
            targets.extend(labels.tolist())
    return running_loss / max(len(loader), 1), predictions, targets


for epoch in range(num_epochs):
    # Training phase (inputs go to `device`, so this also runs on CPU-only machines)
    train_loss = _train_one_epoch(model, train_loader, criterion, optimizer, device)
    # StepLR's step_size is counted in epochs here: step once per epoch, not per
    # batch (per-batch stepping would shrink the LR 10x every 10 batches).
    scheduler.step()

    # Validation phase
    val_loss, all_predictions, all_targets = _evaluate(model, val_loader, criterion, device)

    # Calculate metrics. Binary precision/recall/F1 use label 1 as the positive
    # class; zero_division=0 reports 0 instead of warning when a class is never
    # predicted. labels=[0, 1] keeps the confusion matrix 2x2 even if a class is
    # absent from this epoch's predictions and targets.
    accuracy = 100 * accuracy_score(all_targets, all_predictions)
    precision = precision_score(all_targets, all_predictions, zero_division=0)
    recall = recall_score(all_targets, all_predictions, zero_division=0)
    f1 = f1_score(all_targets, all_predictions, zero_division=0)
    confusion = confusion_matrix(all_targets, all_predictions, labels=[0, 1])

    # Print statistics (both losses are mean per-batch losses)
    print(f'Epoch {epoch+1}/{num_epochs}, Train Loss: {train_loss:.4f}, Validation Loss: {val_loss:.4f}, Accuracy: {accuracy:.2f}%, Precision: {precision:.2f}, Recall: {recall:.2f}, F1 Score: {f1:.2f}')
    print(f'Confusion Matrix:\n{confusion}')

    # Early Stopping Check
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        early_stopping_counter = 0
        # state_dict() tensors share storage with the live parameters, so clone
        # them or later optimizer steps would overwrite the snapshot.
        best_model_state = {name: tensor.detach().clone()
                            for name, tensor in model.state_dict().items()}
    else:
        early_stopping_counter += 1
        if early_stopping_counter >= early_stopping_limit:
            print("Early stopping triggered")
            break

# Leave `model` holding the weights from the epoch with the lowest validation loss.
if best_model_state is not None:
    model.load_state_dict(best_model_state)