In [12]:
import timm
import torch
import torch.nn as nn
import torch.optim as optim
from urllib.request import urlopen
from PIL import Image
from torchvision import transforms
from torchvision import transforms
from torch.utils.data import DataLoader, random_split, Dataset, random_split
import pandas as pd


transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),  # Converts to [0, 1]
    transforms.Normalize(mean=[0.485, 0.456, 0.406],  # ImageNet stats
                         std=[0.229, 0.224, 0.225]),
])

# Load data

In [42]:
# Define custom dataset

class BreakHisDataset(Dataset): # Subclass Dataset, which is required for using DataLoader
    def __init__(self, csv_path, transform=None):
        self.df = pd.read_csv(csv_path)
        self.transform = transform

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        img_path = self.df.loc[idx, 'filepath']
        label = self.df.loc[idx, 'label']
        image = Image.open(img_path).convert("RGB")
        if self.transform:
            image = self.transform(image)
        return image, torch.tensor(label, dtype=torch.long)
    
dataset_train = BreakHisDataset(csv_path="../data/train.csv", transform=transform) # Load the data
dataset_test = BreakHisDataset(csv_path="../data/test.csv", transform=transform)

# Create iterable data loaders

train_loader = DataLoader(dataset_train, batch_size=16, shuffle=True)
test_loader = DataLoader(dataset_test, batch_size=16, shuffle=False) # shuffle=False for consistent evaluation

In [44]:
df = pd.read_csv("../data/test.csv")
df['filepath'][0]

'../data/versions/4/BreaKHis_v1/BreaKHis_v1/histology_slides/breast/malignant/SOB/papillary_carcinoma/SOB_M_PC_15-190EF/40X/SOB_M_PC-15-190EF-40-019.png'

In [31]:
# Load the pre-trained model

model = timm.create_model('convnextv2_atto.fcmae', pretrained=True, num_classes=2)

# Freeze all layers (can then unfreeze sequentially)
for param in model.parameters():
    param.requires_grad = False

# Unfreeze classifier
for param in model.head.parameters():
    param.requires_grad = True
    
criterion = nn.CrossEntropyLoss()
optimiser = optim.Adam(model.head.parameters(), lr=1e-3)

train_accuracies = []
val_accuracies = []
train_losses = []
val_losses = []

In [45]:
n_epochs = 10

for epoch in range(n_epochs):
    model.train()
    
    running_loss = 0.0
    correct = 0
    total = 0
    
    for images, labels in train_loader:
        optimiser.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimiser.step()
        
        running_loss += loss.item() * images.size(0) # Accumulate loss
        
        _, preds = torch.max(outputs, 1) # Take the maximum one as the class with the highest predicted probability => predicted class
        correct += (preds == labels).sum().item() # Number of correct predictions 
        total += labels.size(0) # adding correct predictions
        
    # Update training losses / accuracies
    
    epoch_train_loss = running_loss / total
    train_losses.append(epoch_train_loss)
    
    epoch_train_acc = correct / total
    train_accuracies.append(epoch_train_acc)
    
    # Evaluation phase
    
    model.eval()
    correct_val = 0
    total_val = 0
    running_val_loss = 0.0
    
    with torch.no_grad():
        for images, labels in test_loader:
            outputs = model(images)
            loss = criterion(outputs, labels)
            running_val_loss += loss.item() * images.size(0)
            _, preds = torch.max(outputs, 1)
            correct_val += (preds == labels).sum().item()
            total_val += labels.size(0)
    
    epoch_val_acc = correct_val / total_val
    epoch_val_loss = running_val_loss / total_val
    val_accuracies.append(epoch_val_acc)
    val_losses.append(epoch_val_loss)
    
    print(f"Epoch [{epoch+1}/{n_epochs}], Train loss: {epoch_train_loss}, Test loss: {epoch_val_loss} Train accuracy: {epoch_train_acc}, Test accuracy: {epoch_val_acc}")


TypeError: new(): invalid data type 'str'