In [5]:
import os
import numpy as np 
import matplotlib.pyplot as plt

import torch
from torch.utils.data import DataLoader
import torch.nn.functional as F

In [6]:
# Label names are the entries of the training labels directory
# (e.g. double_plant, water, waterway, ...).
# NOTE(review): os.listdir returns entries in arbitrary order and also returns
# any stray files in that directory; the code visible in this notebook only
# uses len(label_names), as the class count when averaging IoU.
label_names = os.listdir("./dataset/train/labels")
# Target size passed as `size=` to F.interpolate when resizing model outputs.
IMAGE_SIZE = 512

In [7]:
from dataset import dataset

# Build the training and validation datasets from their root folders
# (training first, then validation).
trainset, valset = (
    dataset(root) for root in ("dataset/train", "dataset/val")
)


In [8]:
batch_size = 5

# Training batches are reshuffled every epoch.
train_loader = DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=2)
# Validation only aggregates metrics over the whole set, so shuffling buys
# nothing; a fixed order keeps validation passes reproducible and comparable.
val_loader = DataLoader(valset, batch_size=batch_size, shuffle=False, num_workers=2)

In [9]:
from model import MyModel

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Instantiate the network and move its parameters to the selected device.
model = MyModel().to(device)

In [None]:
from torch.utils.tensorboard import SummaryWriter

# TensorBoard event files are written under ./logs.
writer = SummaryWriter("logs")

# Pull a single training batch purely to record the model graph, then drop
# the reference so the batch does not linger in the notebook namespace.
sample_images, _ = next(iter(train_loader))
writer.add_graph(model, sample_images.to(device))

del sample_images

In [12]:
# Binary cross-entropy on probabilities (default reduction: mean over all
# elements). `nn` is never imported in this notebook, so the loss is reached
# through the already-imported `torch` package to avoid a NameError on a
# fresh kernel.
loss_fn = torch.nn.BCELoss()
# SGD with momentum and L2 weight decay over all model parameters.
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9, weight_decay=1e-4)


In [24]:
def _fused_prediction(outputs):
    """Resize every model output to IMAGE_SIZE and average them into one map.

    `outputs` is the iterable of tensors returned by the model; each one is
    interpolated to the common size before being stacked and averaged.
    """
    resized = [F.interpolate(out, size=IMAGE_SIZE) for out in outputs]
    return torch.mean(torch.stack(resized), dim=0)


@torch.no_grad()
def _miou_sum(pred, target):
    """Return the batch-summed mean IoU as a Python float.

    Predictions are rounded to {0, 1}; intersection and union are counted
    over dims 2 and 3, so both tensors must be 4-D (presumably
    (batch, label, H, W) -- TODO confirm the label axis matches label_names).
    """
    hard = torch.round(pred)
    intersection = torch.logical_and(hard, target).sum(dim=(2, 3))
    union = torch.logical_or(hard, target).sum(dim=(2, 3))
    # Avoid dividing by zero when nothing is predicted and nothing is
    # labelled: the intersection is 0 there too, so the IoU becomes 0, not NaN.
    union = union.clamp(min=1)
    return (torch.sum(intersection / union) / len(label_names)).item()


def train_loop(dataloader, epoch, eval_every=5000):
    """Train for one pass over `dataloader`, validating every `eval_every` batches.

    Args:
        dataloader: iterable of (inputs, targets) training batches; must
            support len() (used for progress output and the logging step).
        epoch: zero-based epoch index.
        eval_every: number of training batches between validation passes
            (default 5000, the previously hard-coded interval).

    Raises:
        ValueError: if `eval_every` is smaller than 1.

    Side effects: updates the global `model` through `optimizer`, prints a
    progress line, and logs "Loss" / "MIOU" scalars to the global `writer`.
    Running sums are reset after each report, so the reported training
    figures cover only the batches since the previous report.
    """
    if eval_every < 1:
        raise ValueError("eval_every must be a positive integer")

    train_loss = 0.0
    train_miou = 0.0
    train_total = 0
    for batch, (X, y) in enumerate(dataloader):
        X, y = X.to(device), y.to(device)
        pred = _fused_prediction(model(X))
        loss = loss_fn(pred, y)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        n_samples = len(y)
        # .item() detaches the value so no autograd history is kept alive
        # across batches. loss_fn returns a batch mean, so it is weighted by
        # the batch size to make loss / sample-count a true per-sample mean.
        train_loss += loss.item() * n_samples
        train_miou += _miou_sum(pred, y)
        train_total += n_samples

        if batch % eval_every != eval_every - 1:
            continue

        # Validate in eval mode, then restore whatever mode the model was in.
        was_training = model.training
        model.eval()
        val_loss = 0.0
        val_miou = 0.0
        val_total = 0
        with torch.no_grad():
            # Distinct names here: the validation loop must not clobber the
            # training batch index used for progress output and logging.
            for X_val, y_val in val_loader:
                X_val, y_val = X_val.to(device), y_val.to(device)
                val_pred = _fused_prediction(model(X_val))
                n_val = len(y_val)
                val_loss += loss_fn(val_pred, y_val).item() * n_val
                val_miou += _miou_sum(val_pred, y_val)
                val_total += n_val
        model.train(was_training)

        # Guard against an empty validation loader.
        val_denominator = max(val_total, 1)

        print(
            f"Epoch: {epoch+1}, {batch+1}/{len(dataloader)}, "
            f"Train Loss: {train_loss/train_total}, "
            f"Validation Loss: {val_loss/val_denominator}, "
            f"Train MIOU: {train_miou/train_total}, "
            f"Validation MIOU: {val_miou/val_denominator}"
        )

        # Global step = number of training batches seen before this one,
        # an integer that increases monotonically across epochs.
        step = epoch * len(dataloader) + batch

        writer.add_scalars("Loss", {
            "Training": train_loss / train_total,
            "Validation": val_loss / val_denominator,
        }, step)

        writer.add_scalars("MIOU", {
            "Training": train_miou / train_total,
            "Validation": val_miou / val_denominator,
        }, step)

        # Start a fresh accumulation window for the next report.
        train_loss = 0.0
        train_miou = 0.0
        train_total = 0
                

In [None]:
epochs = 100

# try/finally so that buffered TensorBoard events are written to disk even
# when the (long) run is interrupted part-way through.
try:
    for epoch in range(epochs):
        train_loop(train_loader, epoch)
        # One checkpoint per epoch, named by the zero-based epoch index.
        # NOTE(review): this saves the whole model object rather than its
        # state_dict -- confirm that downstream loading code expects that.
        torch.save(model, f"./{epoch}.pt")
finally:
    writer.flush()