In [1]:
import torch
from torch.utils.data import Dataset, DataLoader, Subset
import matplotlib.pyplot as plt
from torchvision import transforms
import torchvision
import pandas as pd
from PIL import Image
import os
from sklearn.model_selection import train_test_split
from sklearn.metrics import fbeta_score
from itertools import cycle
import torch.nn.functional as F

In [2]:
def _imagenet_eval_pipeline(crop_size, resize_size=256):
    """Build an evaluation-style preprocessing pipeline for an ImageNet-pretrained backbone.

    Resizes the shorter image side to `resize_size`, centre-crops to `crop_size`, converts
    to a tensor and normalises with the standard ImageNet channel mean/std.
    """
    # Go through `torchvision.transforms` rather than the bare name `transforms`: the dict
    # below rebinds `transforms`, so using the bare name would break on a second run of this cell.
    tv = torchvision.transforms
    return tv.Compose([
        tv.Resize(resize_size),
        tv.CenterCrop(crop_size),
        tv.ToTensor(),
        tv.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])


# Preprocessing pipeline per model family (the name `transforms` is kept for downstream cells).
transforms = {
    'ResNet': _imagenet_eval_pipeline(224),        # ResNet: 256 -> 224 centre crop
    'EfficientNet': _imagenet_eval_pipeline(240),  # EfficientNet_b1: 256 -> 240 centre crop
}

In [3]:
# Label file: one row per training image, with its tags as a space-separated string.
train_labels_path = "data/train_classes.csv"
df = pd.read_csv(train_labels_path)
df

Unnamed: 0,image_name,tags
0,train_0,haze primary
1,train_1,agriculture clear primary water
2,train_2,clear primary
3,train_3,clear primary
4,train_4,agriculture clear habitation primary road
...,...,...
40474,train_40474,clear primary
40475,train_40475,cloudy
40476,train_40476,agriculture clear primary
40477,train_40477,agriculture clear primary road


In [4]:
# Vocabulary: every distinct tag that appears anywhere in the `tags` column.
all_tags = {tag for tag_list in df['tags'].str.split() for tag in tag_list}

In [5]:
# Deterministic (alphabetical) tag <-> index mapping used to build multi-hot label vectors.
sorted_tags = sorted(all_tags)
tag_to_idx = dict(zip(sorted_tags, range(len(sorted_tags))))
idx_to_tag = dict(enumerate(sorted_tags))
print(tag_to_idx)
print(len(tag_to_idx))

{'agriculture': 0, 'artisinal_mine': 1, 'bare_ground': 2, 'blooming': 3, 'blow_down': 4, 'clear': 5, 'cloudy': 6, 'conventional_mine': 7, 'cultivation': 8, 'habitation': 9, 'haze': 10, 'partly_cloudy': 11, 'primary': 12, 'road': 13, 'selective_logging': 14, 'slash_burn': 15, 'water': 16}
17


In [6]:
class MultiLabelImageDataset(Dataset):
    """Dataset yielding `(image, multi_hot_labels)` pairs from a CSV of image names and tags.

    Column 0 of the CSV holds the image name (the file read is `<img_dir>/<name>.jpg`) and
    column 1 a whitespace-separated tag string.

    Args:
        csv_file: path to the label CSV.
        img_dir: directory containing the `.jpg` files.
        transform: optional callable applied to the loaded RGB PIL image.
        tag_vocab: optional mapping tag -> label index. Defaults to the notebook-level
            `tag_to_idx`, copied once here so later rebinding of that global cannot
            silently change this dataset's labels.
    """

    def __init__(self, csv_file, img_dir, transform=None, tag_vocab=None):
        self.df = pd.read_csv(csv_file)
        self.img_dir = img_dir
        self.transform = transform
        # The global is only looked up when no explicit vocabulary is supplied.
        self.tag_to_idx = dict(tag_to_idx if tag_vocab is None else tag_vocab)

    def __len__(self):
        return len(self.df)

    def _encode_tags(self, tag_string):
        """Turn a whitespace-separated tag string into a float multi-hot vector.

        Raises:
            KeyError: if a tag is not in `self.tag_to_idx`.
        """
        labels = torch.zeros(len(self.tag_to_idx))
        for tag in tag_string.split():
            labels[self.tag_to_idx[tag]] = 1
        return labels

    def __getitem__(self, idx):
        img_name = self.df.iloc[idx, 0]
        img_path = os.path.join(self.img_dir, f"{img_name}.jpg")
        # Close the file handle deterministically; convert() returns a converted copy.
        with Image.open(img_path) as img:
            image = img.convert('RGB')

        if self.transform:
            image = self.transform(image)

        return image, self._encode_tags(self.df.iloc[idx, 1])

In [7]:
# Every ensemble member is built with ResNetClassifier, so all three datasets use the
# ResNet preprocessing -- including the one feeding the member still named `effnet`.
_dataset_kwargs = dict(csv_file="data/train_classes.csv", img_dir="data/train-jpg")
resnet_dataset = MultiLabelImageDataset(transform=transforms['ResNet'], **_dataset_kwargs)
effnet_dataset = MultiLabelImageDataset(transform=transforms['ResNet'], **_dataset_kwargs)
resnet2_dataset = MultiLabelImageDataset(transform=transforms['ResNet'], **_dataset_kwargs)

In [8]:
def split_dataset(dataset, train_idx, test_idx):
    """Return `(train_subset, test_subset)`: `Subset` views of `dataset` at the given indices."""
    return Subset(dataset, train_idx), Subset(dataset, test_idx)

# Seeded split so the same images land in train/test on every run and after kernel restarts.
# NOTE(review): 42 is the value that was previously commented out; confirm it matches the split
# the loaded checkpoints were trained on, otherwise their training images may sit in this test set.
SPLIT_SEED = 42

dataset_length = len(resnet_dataset)
train_idx, test_idx = train_test_split(
    list(range(dataset_length)),
    test_size=0.1,
    random_state=SPLIT_SEED,
)

# All three datasets are built from the same CSV, so shared indices give aligned splits.
resnet_train, resnet_test = split_dataset(resnet_dataset, train_idx, test_idx)
effnet_train, effnet_test = split_dataset(effnet_dataset, train_idx, test_idx)
resnet2_train, resnet2_test = split_dataset(resnet2_dataset, train_idx, test_idx)

In [9]:
batch_size = 32


def _make_loader_pair(train_subset, test_subset):
    """Return `(train_loader, test_loader)`: shuffled batches for training, fixed order for testing."""
    return (
        DataLoader(train_subset, batch_size=batch_size, shuffle=True),
        DataLoader(test_subset, batch_size=batch_size, shuffle=False),
    )


resnet_train_loader, resnet_test_loader = _make_loader_pair(resnet_train, resnet_test)
effnet_train_loader, effnet_test_loader = _make_loader_pair(effnet_train, effnet_test)
resnet2_train_loader, resnet2_test_loader = _make_loader_pair(resnet2_train, resnet2_test)

In [10]:
# Prefer CUDA, then Apple's MPS backend, and fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"
print(f"Using {device} device")

Using cuda device


In [11]:
import torchvision.models as models
from torch import nn

# Number of model outputs = size of the tag vocabulary (17 for this label file), derived
# rather than hard-coded so it cannot drift from `tag_to_idx`.
num_classes = len(tag_to_idx)

def ResNetClassifier(num_classes, freeze_backbone=False, weights='DEFAULT'):
    """Build a torchvision ResNet-50 whose final fully connected layer outputs `num_classes` logits.

    Args:
        num_classes: number of output logits.
        freeze_backbone: if True, disable gradients for every pre-existing parameter so only
            the new head is trained. Defaults to False (full fine-tuning, as before).
        weights: weights spec forwarded to `models.resnet50`; 'DEFAULT' loads pre-trained weights.

    Returns:
        The ResNet-50 module with its `fc` layer replaced by a fresh `nn.Linear`.
    """
    model_ft = models.resnet50(weights=weights)

    if freeze_backbone:
        for param in model_ft.parameters():
            param.requires_grad = False

    # A newly constructed Linear layer is trainable even when the backbone is frozen.
    model_ft.fc = nn.Linear(model_ft.fc.in_features, num_classes)
    return model_ft

def EfficientNetClassifier(num_classes, freeze_backbone=False, weights='DEFAULT'):
    """Build a torchvision EfficientNet-B1 whose classifier Linear layer outputs `num_classes` logits.

    Args:
        num_classes: number of output logits.
        freeze_backbone: if True, disable gradients for every pre-existing parameter so only
            the new head is trained. Defaults to False (full fine-tuning, as before).
        weights: weights spec forwarded to `models.efficientnet_b1`; 'DEFAULT' loads pre-trained weights.

    Returns:
        The EfficientNet-B1 module with `classifier[1]` replaced by a fresh `nn.Linear`.
    """
    model_ft = models.efficientnet_b1(weights=weights)

    if freeze_backbone:
        for param in model_ft.parameters():
            param.requires_grad = False

    # Index 1 of `classifier` is the Linear layer being replaced; a new layer is trainable.
    model_ft.classifier[1] = nn.Linear(model_ft.classifier[1].in_features, num_classes)
    return model_ft

In [12]:
class EnsembleModel(nn.Module):
    """Ensemble of three multi-label classifiers whose logits are fused by a configurable rule.

    All three members are built with `ResNetClassifier`. The attribute names `resnet`,
    `effnet` and `resnet2` are kept because they prefix this module's state_dict keys
    (the member named `effnet` is a ResNet, not an EfficientNet).

    Args:
        num_classes: number of output logits per member.
        ensemble_type: fusion rule -- 'voting' (plain mean), 'weighted' (learnable softmax
            weights), 'dudani' or 'shepard' (per-sample weights from each member's
            confidence); any other value falls back to the plain mean.
        checkpoint_paths: three state_dict files for (resnet, effnet, resnet2);
            defaults to `DEFAULT_CHECKPOINTS`.

    Raises:
        ValueError: if `checkpoint_paths` does not contain exactly three entries.
    """

    # Default member checkpoints in (resnet, effnet, resnet2) order.
    # NOTE(review): the first and third entries are the same file, so those two members
    # start out identical -- confirm this is intended.
    DEFAULT_CHECKPOINTS = (
        "LP_oversampled_ResNet50_0.5epochs_1e-4_ADAM.pth",
        "ML_oversampled_augmented_ResNet50_5epochs_1e-4_ADAM.pth",
        "LP_oversampled_ResNet50_0.5epochs_1e-4_ADAM.pth",
    )

    def __init__(self, num_classes, ensemble_type='weighted', checkpoint_paths=None):
        super().__init__()

        paths = tuple(self.DEFAULT_CHECKPOINTS if checkpoint_paths is None else checkpoint_paths)
        if len(paths) != 3:
            raise ValueError(f"expected 3 checkpoint paths, got {len(paths)}")

        self.resnet = ResNetClassifier(num_classes)
        self.effnet = ResNetClassifier(num_classes)
        self.resnet2 = ResNetClassifier(num_classes)

        for member, path in zip((self.resnet, self.effnet, self.resnet2), paths):
            # map_location='cpu' lets checkpoints saved on a GPU load on any machine;
            # the values are copied into the (CPU) member parameters either way.
            member.load_state_dict(torch.load(path, map_location='cpu', weights_only=True))

        self.ensemble_type = ensemble_type

        # Learnable fusion logits for 'weighted'; equal values -> softmax gives 1/3 each.
        if ensemble_type == 'weighted':
            self.weights = nn.Parameter(torch.ones(3) / 3)

        # Shepard's rule parameters: weight = exp(-a * |distance| ** b).
        self.a = 1.0
        self.b = 1.0

    def forward(self, x):
        """Run all three members on the same batch `x` and return the fused logits."""
        return self.combine_outputs(self.resnet(x), self.effnet(x), self.resnet2(x))

    def combine_outputs(self, resnet_out, effnet_out, resnet2_out):
        """Fuse the members' logits (assumed (batch, num_classes)) according to `ensemble_type`."""
        outputs = (resnet_out, effnet_out, resnet2_out)

        if self.ensemble_type == 'weighted':
            w = nn.functional.softmax(self.weights, dim=0)
            return w[0] * resnet_out + w[1] * effnet_out + w[2] * resnet2_out

        if self.ensemble_type in ('dudani', 'shepard'):
            # Per-sample "distance" of each member = minus its highest sigmoid confidence,
            # so the most confident member is the nearest. Shape (batch, 3).
            distances = torch.stack(
                [-torch.sigmoid(out).max(dim=1).values for out in outputs], dim=1
            )
            if self.ensemble_type == 'dudani':
                member_weights = self._dudani_weights(distances)
            else:
                member_weights = torch.exp(-self.a * torch.abs(distances) ** self.b)

            member_weights = member_weights / member_weights.sum(dim=1, keepdim=True)
            return sum(member_weights[:, i].unsqueeze(1) * out for i, out in enumerate(outputs))

        # 'voting' and any unrecognised type: unweighted mean of the logits.
        return (resnet_out + effnet_out + resnet2_out) / 3

    @staticmethod
    def _dudani_weights(distances):
        """Dudani's rule: w = (d_max - d) / (d_max - d_min); every member gets w = 1 when all tie."""
        d_min = distances.min(dim=1, keepdim=True).values
        d_max = distances.max(dim=1, keepdim=True).values
        spread = d_max - d_min
        tied = spread == 0
        # Replace a zero spread by 1 before dividing so neither torch.where branch holds NaN
        # (a NaN in the unused branch would still poison the gradients).
        safe_spread = torch.where(tied, torch.ones_like(spread), spread)
        return torch.where(tied, torch.ones_like(distances), (d_max - distances) / safe_spread)

In [13]:
def train_ensemble_model(
    resnet_train_loader, 
    effnet_train_loader,
    resnet2_train_loader,
    resnet_test_loader, 
    effnet_test_loader,
    resnet2_test_loader,
    num_classes, 
    epochs, 
    learning_rate, 
    threshold,
    ensemble_type) :

    model = EnsembleModel(num_classes, ensemble_type).to(device)
    loss_fn = nn.BCEWithLogitsLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.75)
    
    # training loop
    all_loss = []
    for t in range(epochs):
        print(f"Epoch {t+1}\n-------------------------------")
        
        dataloaders = [resnet_train_loader, effnet_train_loader, resnet2_train_loader]
        max_loader_len = max(len(loader) for loader in dataloaders)
        
        model.train()
        total_loss = 0
        for batch_idx in range(max_loader_len):
            # cycle through dataloaders
            X_resnet, y_resnet = next(cycle(resnet_train_loader))
            X_effnet, y_effnet = next(cycle(effnet_train_loader))
            X_resnet2, y_resnet2 = next(cycle(resnet2_train_loader))
            
            X_resnet = X_resnet.to(device)
            X_effnet = X_effnet.to(device)
            X_resnet2 = X_resnet2.to(device)
            
            y_resnet = y_resnet.to(device)
            y_effnet = y_effnet.to(device)
            y_resnet2 = y_resnet2.to(device)
            
            # get predictions
            resnet_out = model.resnet(X_resnet)
            effnet_out = model.effnet(X_effnet)
            resnet2_out = model.resnet2(X_resnet2)
            
            # compute losses
            loss_resnet = loss_fn(resnet_out, y_resnet)
            loss_effnet = loss_fn(effnet_out, y_effnet)
            loss_resnet2 = loss_fn(resnet2_out, y_resnet2)
            
            # total loss
            loss = (loss_resnet + loss_effnet + loss_resnet2) / 3
            
            # backprop
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            
            total_loss += loss.item()
            
            if batch_idx % 64 == 0:
                print(f"loss: {loss.item():>7f}")
        
        # validation loop
        model.eval()
        test_loss, f2 = 0, 0
        num_batches = min(len(resnet_test_loader), len(effnet_test_loader), len(resnet2_test_loader))
        
        with torch.no_grad():
            for (X_resnet, y_resnet), (X_effnet, y_effnet), (X_resnet2, y_resnet2) in zip(
                resnet_test_loader, effnet_test_loader, resnet2_test_loader
            ):
                X_resnet = X_resnet.to(device)
                X_effnet = X_effnet.to(device)
                X_resnet2 = X_resnet2.to(device)
                
                y_resnet = y_resnet.to(device)
                y_effnet = y_effnet.to(device)
                y_resnet2 = y_resnet2.to(device)
                
                # get model predictions
                pred_resnet = model.resnet(X_resnet)
                pred_effnet = model.effnet(X_effnet)
                pred_resnet2 = model.resnet2(X_resnet2)
                
                # ensemble prediction
                if model.ensemble_type == 'weighted':
                    normalized_weights = F.softmax(model.weights, dim=0)
                    pred = (
                        normalized_weights[0] * pred_resnet + 
                        normalized_weights[1] * pred_effnet + 
                        normalized_weights[2] * pred_resnet2
                    )

                elif model.ensemble_type == 'dudani':
                    # dudani's rule dynamic weights
                    distances = torch.stack([
                        -torch.max(torch.sigmoid(pred_resnet), dim=1)[0],
                        -torch.max(torch.sigmoid(pred_effnet), dim=1)[0],
                        -torch.max(torch.sigmoid(pred_resnet2), dim=1)[0]
                    ], dim=1)

                    d1, _ = torch.min(distances, dim=1, keepdim=True)
                    dk, _ = torch.max(distances, dim=1, keepdim=True)
                    
                    dudani_weights = (dk - distances) / (dk - d1 + 1e-9)
                    dudani_weights[distances == d1] = 1.0
                    dudani_weights[distances == dk] = 0.0
                    dudani_weights = dudani_weights / dudani_weights.sum(dim=1, keepdim=True)
                    
                    pred = (
                        dudani_weights[:, 0].unsqueeze(1) * pred_resnet +
                        dudani_weights[:, 1].unsqueeze(1) * pred_effnet +
                        dudani_weights[:, 2].unsqueeze(1) * pred_resnet2
                    )

                elif model.ensemble_type == 'shepard':
                    # shepard's rule dynamic weights
                    distances = torch.stack([
                        -torch.max(torch.sigmoid(pred_resnet), dim=1)[0],
                        -torch.max(torch.sigmoid(pred_effnet), dim=1)[0],
                        -torch.max(torch.sigmoid(pred_resnet2), dim=1)[0]
                    ], dim=1)  # Shape: [batch_size, num_models]
                    
                    # shepard weights
                    shepard_weights = torch.exp(-model.a * torch.abs(distances) ** model.b)
                    shepard_weights = shepard_weights / shepard_weights.sum(dim=1, keepdim=True)
                    
                    # ensemble prediction
                    pred = (
                        shepard_weights[:, 0].unsqueeze(1) * pred_resnet + 
                        shepard_weights[:, 1].unsqueeze(1) * pred_effnet + 
                        shepard_weights[:, 2].unsqueeze(1) * pred_resnet2
                    )

                else:
                    pred = (pred_resnet + pred_effnet + pred_resnet2) / 3
                
                # compute test loss
                test_loss += loss_fn(pred, y_resnet).item()
                
                # calculate f2 score
                pred_tags = torch.sigmoid(pred).cpu().numpy() > threshold
                true_tags = y_resnet.cpu().numpy()
                f2 += fbeta_score(true_tags, pred_tags, beta=2, average='micro')
        
        test_loss /= num_batches
        f2 /= num_batches
        
        print(f"Test Error: \n f2 score: {f2:.5f}, avg loss: {test_loss:>8f} \n")
        all_loss.append(test_loss)

        scheduler.step()
        print(scheduler.get_last_lr())
    
    return model, all_loss

In [14]:
# Initial Adam learning rate, number of epochs, and the sigmoid cut-off that turns scores into tags.
learning_rate, epochs, threshold = 1e-4, 1, 0.24

In [15]:
# Fine-tune the three members and evaluate them fused with the 'weighted' rule.
ensemble_model, loss_history = train_ensemble_model(
    resnet_train_loader, effnet_train_loader, resnet2_train_loader,
    resnet_test_loader, effnet_test_loader, resnet2_test_loader,
    num_classes=num_classes,
    epochs=epochs,
    learning_rate=learning_rate,
    threshold=threshold,
    ensemble_type='weighted',
)

Epoch 1
-------------------------------
loss: 0.078989
loss: 0.039761
loss: 0.045638
loss: 0.046747
loss: 0.039083
loss: 0.042317
loss: 0.032731
loss: 0.026938
loss: 0.029724
loss: 0.028663
loss: 0.027427
loss: 0.045305
loss: 0.022390
loss: 0.029219
loss: 0.028084
loss: 0.023539
loss: 0.022054
loss: 0.021091
Test Error: 
 f2 score: 0.97803, avg loss: 0.027023 

[7.500000000000001e-05]


In [16]:
# Persist the trained ensemble's parameters (state_dict only, not the module object).
ensemble_checkpoint_path = "LP_OS_ML_OS_ResNet50.pth"
torch.save(ensemble_model.state_dict(), ensemble_checkpoint_path)