In [1]:
import timm
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import transforms
from sklearn.metrics import accuracy_score, f1_score, recall_score, precision_score
from dataset import BrainTumorDataset

In [2]:
# Data Transformations
data_transforms = {
    'train': transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
    'val': transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
}

In [3]:
train_dataset = BrainTumorDataset(root_dir='/Users/nasifsafwan/Downloads/ML/BrainTumorResearch/tumordata/Training/',
                                 transform=data_transforms['train'])
val_dataset = BrainTumorDataset(root_dir='/Users/nasifsafwan/Downloads/ML/BrainTumorResearch/tumordata/Testing/'
                                 ,transform=data_transforms['val'])



In [4]:
train_loader = DataLoader(train_dataset,batch_size=32,shuffle=True, num_workers=4)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False, num_workers=4)

In [5]:
import timm
import torch
import torch.nn as nn

class CombinedModel(nn.Module):
    def __init__(self, models, num_classes):
        super(CombinedModel, self).__init__()
        self.models = models
        
        # Get the feature sizes for each model
        feature_sizes = []
        for model in self.models:
            feature_sizes.append(model.num_features)
        
        # Combined feature size
        combined_feature_size = sum(feature_sizes)
        
        # Define the final fully connected layer
        self.fc = nn.Linear(combined_feature_size, num_classes)
        
    def forward(self, x):
        outputs = []
        for model in self.models:
            outputs.append(model(x))
        combined_output = torch.cat(outputs, dim=1)
        x = self.fc(combined_output)
        return x

# Initialize models
model1 = timm.create_model('swin_base_patch4_window7_224', pretrained=True, num_classes=0)  # Output features only
model2 = timm.create_model('resnet50', pretrained=True, num_classes=0)  # Output features only
model3 = timm.create_model('efficientnet_b0', pretrained=True, num_classes=0)  # Output features only

# Create the combined model
num_classes = 4  # Adjust as necessary
combined_model = CombinedModel([model1, model2, model3], num_classes=num_classes)

# Print the model to verify
print(combined_model)

model.safetensors:  31%|###       | 31.5M/102M [00:00<?, ?B/s]

CombinedModel(
  (fc): Linear(in_features=4352, out_features=4, bias=True)
)


In [6]:
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(combined_model.parameters(), lr=1e-4)

In [7]:
from tqdm import tqdm


num_epochs = 20
for epoch in range(num_epochs):
    combined_model.train()
    running_loss = 0.0
    train_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}", ncols=100)
    
    for inputs, labels in train_bar:
        optimizer.zero_grad()
        outputs = combined_model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item() * inputs.size(0)
        train_bar.set_postfix(loss=running_loss / ((train_bar.n + 1) * train_loader.batch_size))

    epoch_loss = running_loss / len(train_loader.dataset)
    print(f'Epoch {epoch+1}/{num_epochs}, Loss: {epoch_loss:.4f}')
    
    # Validation step
    combined_model.eval()
    running_corrects = 0
    
    val_bar = tqdm(val_loader, desc="Validation", ncols=100)
    with torch.no_grad():
        for inputs, labels in val_bar:
            outputs = combined_model(inputs)
            _, preds = torch.max(outputs, 1)
            running_corrects += torch.sum(preds == labels.data)
    
    epoch_acc = running_corrects.double() / len(val_loader.dataset)
    print(f'Epoch {epoch+1}/{num_epochs}, Accuracy: {epoch_acc:.4f}')
    
    # Optional: Step the scheduler
    #scheduler.step()

Epoch 1/20:  48%|███████████████████                     | 43/90 [53:15<58:12, 74.31s/it, loss=1.25]


KeyboardInterrupt: 