In [None]:
# lovasz softmax
from lovasz_softmax import lovasz_softmax

ModuleNotFoundError: No module named 'lovasz_softmax'

In [None]:
from google.colab import drive
drive.mount('/content/drive', force_remount = True)
save_path = None

In [None]:
from torch.utils.data import Dataset
import torchvision.transforms as transforms
from PIL import Image

class CustomDataset(Dataset):

  def __init__(self, baseDataset):
    self.ds = baseDataset
    self.transform = transform = transforms.Compose([
    transforms.RandomHorizontalFlip(p = 0.5),
    transforms.RandomApply([transforms.RandomRotation(degrees=10)], p = 0.5),
    transforms.ToTensor()
    ])


  def __len__(self):
    return len(self.ds)

  def __getitem__(self, idx):
    image, label = self.ds[idx]
    image = Image.fromarray((image * 255).astype('uint8'))  # Assuming array is scaled [0, 1]
    image = self.transform(image)
    return image, label



In [None]:
import torch
import torch.nn.functional as F

class WeightedCrossEntropyLoss(torch.nn.Module):
    def __init__(self, class_weights):
        super().__init__()
        self.class_weights = class_weights

    def forward(self, outputs, targets):
        return F.cross_entropy(outputs, targets, weight = self.class_weights)

class CombinedLoss(torch.nn.Module):
    def __init__(self, class_weights):
        super().__init__()
        self.wce_loss = WeightedCrossEntropyLoss(class_weights)
        self.lovasz_loss = lovasz_softmax

    def forward(self, outputs, targets):
        loss_wce = self.wce_loss(inputs, targets)
        loss_lovasz = self.lovasz_loss(outputs, targets)
        return loss_wce + loss_lovasz

In [3]:
from torch.optim import SGD
from torch.optim.lr_scheduler import StepLR

save_path = None # TODO - set it as the name of the folder in Google Drive where you want to save all the model weights

def train_model(model, train_loader, val_loader, class_weights, num_epochs, is_graph = True):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)

    optimizer = SGD(model.parameters(), lr = 0.01, momentum= 0.9, weight_decay = 0.0001)
    scheduler = StepLR(optimizer, 1, gamma = 0.01)

    loss_fn = CombinedLoss(class_weights)
    loss_values = []

    train_acc_values, val_acc_values = [], []

    for epoch in range(num_epochs):
        model.train()
        total_loss = 0

        correct_train, total_train = 0, 0
        for inputs, targets, _, _ in train_loader:

            # inputs -> B * 5 * H * W, targets: B * H * W

            # targets[v, u] -> B
            inputs, targets = inputs.to(device), targets.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)  # B * 20 * H * W
            loss = loss_fn(outputs, targets)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()

             # Calculate training accuracy
            if is_graph:
              _, predicted = torch.max(outputs, 1)
              correct_train += (predicted == targets).sum().item()
              total_train += targets.numel()

        # Validation phase
        if is_graph:
          model.eval()
          correct_val, total_val = 0, 0

          with torch.no_grad():
              for inputs, targets, _, _ in val_loader:
                  inputs, targets = inputs.to(device), targets.to(device)
                  outputs = model(inputs)
                  _, predicted = torch.max(outputs, 1)
                  correct_val += (predicted == targets).sum().item()
                  total_val += targets.numel()

          train_accuracy, val_accuracy = 100 * correct_val / total_val, 100 * correct_train / total_train

          train_acc_values.append(train_accuracy)
          val_acc_values.append(val_accuracy)

          # saving weights every 10 epochs
          if (epoch + 1) % 10 == 0:
              torch.save(model.state_dict(), f'{save_path}/model_epoch_{epoch + 1}.pth')

        scheduler.step()

        epoch_loss = total_loss / len(train_loader)




        if is_graph:
            loss_values.append(epoch_loss)

        print(f'Epoch {epoch + 1}, Loss: {epoch_loss}')

    if is_graph:
        plt.figure(figsize=(10, 5))
        plt.subplot(1, 2, 1)
        plt.plot(range(1, num_epochs + 1), loss_values)
        plt.xlabel('Epoch')
        plt.ylabel('Loss')
        plt.title('Training Loss')

        plt.subplot(1, 2, 2)
        plt.plot(range(1, num_epochs + 1), train_acc_values, label='Train')
        plt.plot(range(1, num_epochs + 1), val_acc_values, label='Validation')
        plt.xlabel('Epoch')
        plt.ylabel('Accuracy')
        plt.title('Training and Validation Accuracy')
        plt.legend()
        plt.show()

In [None]:
def get_class_weights(dataloader, num_classes = 20):
    class_counts = {}

    for _, labels in dataloader:
        # assuming label is int
        for label in labels:
            class_counts[label] = class_counts.get(label, 0) + 1


    class_frequencies = torch.zeros(num_classes)
    for key, value in class_counts.items():
        class_frequencies[key] = value

    weights = 1.0 / torch.sqrt(class_frequencies)
    weights[class_frequencies == 0] = 0

    return weights


In [None]:
from torch.utils.data import DataLoader

trainset = None # TODO
testset = None #TODO

batch_size = 24

# the purpose of the CustomDataset is to apply the random flips and rotations. If your dataset already takes that into account, you may not need to convert your dataset into CustomDataset format

train_loader = DataLoader(CustomDataset(trainset), batch_size=batch_size,
                                          shuffle=True, num_workers=2)
testloader = DataLoader(CustomDataset(testset), batch_size=batch_size,
                                         shuffle=False, num_workers=2)

val_loader = DataLoader(CustomDataset(testset), batch_size=batch_size,
                                         shuffle=False, num_workers=2)

model = None # TODO: create model instance


num_epochs = 10 # TODO: play with this hyperparameter
class_weights = get_class_weights(train_loader)
train_model(model, train_loader, val_loader, class_weights, num_epochs)

TypeError: object of type 'NoneType' has no len()