In [None]:
!pip install wandb gdown timm h5py

In [1]:
import numpy as np
from torch.utils.data import Dataset, DataLoader
import torch
import torch.nn as nn
import torch.optim as optim
#import geopandas as gpd
import wandb
import matplotlib.pyplot as plt
import h5py
import time
from convnextv2_unet import ConvNeXtV2_unet
from simple_unet import UNet
from training_utils import evaluate

In [2]:
torch.cuda.is_available()

True

In [None]:
!gdown --id 15qEZ6nMJ1xLD5l4VQ3bUR21FVDq_QTBU
!gdown --id 1a3cGzN6xncf7BH3TmFwqHtG9tOk_gXO7

In [3]:
wandb.login()

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mmoraesd90[0m ([33mt5_ssl4eo[0m). Use [1m`wandb login --relogin`[0m to force relogin


True

In [6]:
class HDF5Dataset(Dataset):
    def __init__(self, hdf5_path, transform=None):
        self.hdf5_path = hdf5_path
        self.transform = transform
        self.h5file = h5py.File(hdf5_path, 'r')
        self.size = self.h5file['labels'].shape[0] #.size
        self.data = torch.from_numpy(self.h5file['crops'][:].astype(np.float32) / 10000.0)
        self.labels = torch.as_tensor(self.h5file['labels'][:],dtype=torch.long)
        self.labels[self.labels==255]=20 #added for test with weights - TEMP
        
    def __len__(self):
        return self.size
    
    
    def __getitem__(self, idx):
        crop = self.data[idx]
        label = self.labels[idx]

        if self.transform:
            #label = label.unsqueeze(0) # Add a channel dimension to the label
            # Concatenate the label as an additional channel to the crop
            combined = torch.cat((crop, label.unsqueeze(0)), dim=0)
            combined = self.transform(combined)

            # Split the crop and label back into separate tensors
            crop = combined[:-1]  # All but the last channel
            label = combined[-1].long()  # The last channel

        return crop, label
    

In [7]:
from torchvision.transforms import v2
train_transforms = v2.Compose([
    v2.RandomHorizontalFlip(),
    v2.RandomVerticalFlip(),
    #v2.RandomSolarize(0.05)
    #v2.ToTensor()  # Convert image to PyTorch tensor
])
#train_transforms = None


In [10]:
#from torchinfo import summary
#model = model.cuda() #send model to device
#summary(model)

In [11]:
train_set = HDF5Dataset("crops_train_seg_all_sel.hdf5", transform=train_transforms)
test_set = HDF5Dataset("crops_test_seg_all_sel.hdf5")


In [12]:
def trainModel(model, train_loader, test_loader, test_eval, optimizer, criterion, config_wandb):

    wandb.init(project="ifn-weakly-supervised-seg-v2-vast", config=config_wandb)
    
    for epoch in range(config_wandb['epochs']):
        tstart = time.time()
        model.train()
        running_loss = 0.0
        correct = 0
        total = 0
        for inputs, labels in train_loader:
            # Transfer to GPU 
            inputs, labels = inputs.cuda(), labels.cuda()

            # Zero the parameter gradients
            optimizer.zero_grad()

            # Forward pass
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            
            # Backward pass and optimize
            loss.backward()
            optimizer.step()

            _, predicted = torch.max(outputs, dim=1)
            mask = labels!=20
            labels_selected = labels[mask]
            predicted_selected = predicted[mask]
            correct += (predicted_selected == labels_selected).sum().item()  # Sum the correct predictions
            total += labels_selected.size(0)

            # Print statistics
            running_loss += loss.item()
            
        avg_loss = running_loss / len(train_loader)
        if test_eval:
            val_loss, val_accuracy = evaluate(model, test_loader, criterion, num_classes)
        else:
            val_loss, val_accuracy = (0, 0)
        wandb.log({"epoch": epoch, "train_loss": avg_loss, "train_acc":correct/total, "val_loss":val_loss, "val_acc":val_accuracy})
        print(f"Epoch [{epoch+1}/{num_epochs}], Average Loss: {avg_loss:.4f}, Acc: {correct / total :.4f}, Test_eval: {str(test_eval)}, Time/epoch: {round((time.time()-tstart)/60,2)}min")
    
    wandb.finish()


In [14]:
from itertools import product

num_classes = 21

ws = [1 for i in range(21)] #loss weights
ws[-1] = 0
ws = torch.tensor(ws).float().cuda()

num_epochs = 150
criterion = nn.CrossEntropyLoss(weight=ws)
test_eval = True

batch_size_list = [32, 64]
lr_list = [0.01, 0.001, 0.0001]
depths_list = [
    [2, 2, 6, 2],
]
dims_list = [
    [40, 80, 160, 320],
]

for batch_size, lr, depth, dim in product(batch_size_list, lr_list, depths_list, dims_list):
    #set data loaders, which will vary according to the batch_size
    train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=14,pin_memory=True,persistent_workers=True)
    test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=False, num_workers=14,pin_memory=True)
    #set model, which will vary according to the depth and dim
    model = None #first make sure no previous model was initialized
    model = UNet(in_channels=36, out_channels=num_classes)
    #model = ConvNeXtV2_unet(in_chans=36, num_classes=20, depths=depth, dims=dim, use_orig_stem=False)
    #model = ConvNeXtV2(in_chans=36, num_classes=20, depths=depth, dims=dim)
    #set optimizer, which will vary according to the learning rate
    optimizer = optim.Adam(model.parameters(), lr=lr)
    
    #set wandb config dictionary, which will vary depending on the parameters
    config_wandb = {
    "optimizer": "Adam", #fixed
    "criterion": "CrossEntropyLoss_ws", #fixed
    "learning_rate": lr, 
    "epochs": num_epochs, #fixed
    "batch_size": batch_size,
    "augmentations":"H&V Flip", #fixed
    "architecture":"UNet_Simple",#"ConvNextV2_mod", #fixed
    "depth": depth,
    "dim": dim
    }

    model = model.cuda() #send model to device

    #train model
    trainModel(model, train_loader, test_loader, test_eval, optimizer, criterion, config_wandb)

Epoch [1/200], Average Loss: 2.6746, Acc: 0.1252, Test_eval: True, Time/epoch: 0.61min
Epoch [2/200], Average Loss: 2.5753, Acc: 0.1583, Test_eval: True, Time/epoch: 0.54min
Epoch [3/200], Average Loss: 2.5218, Acc: 0.1754, Test_eval: True, Time/epoch: 0.53min
Epoch [4/200], Average Loss: 2.4930, Acc: 0.1837, Test_eval: True, Time/epoch: 0.55min
Epoch [5/200], Average Loss: 2.4649, Acc: 0.1949, Test_eval: True, Time/epoch: 0.53min
Epoch [6/200], Average Loss: 2.4406, Acc: 0.2046, Test_eval: True, Time/epoch: 0.53min
Epoch [7/200], Average Loss: 2.4268, Acc: 0.2115, Test_eval: True, Time/epoch: 0.53min
Epoch [8/200], Average Loss: 2.4034, Acc: 0.2173, Test_eval: True, Time/epoch: 0.53min
Epoch [9/200], Average Loss: 2.3857, Acc: 0.2232, Test_eval: True, Time/epoch: 0.53min
Epoch [10/200], Average Loss: 2.3714, Acc: 0.2281, Test_eval: True, Time/epoch: 0.53min
Epoch [11/200], Average Loss: 2.3615, Acc: 0.2324, Test_eval: True, Time/epoch: 0.53min
Epoch [12/200], Average Loss: 2.3447, Acc

KeyboardInterrupt: 

In [15]:
wandb.finish()

VBox(children=(Label(value='0.002 MB of 0.002 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
epoch,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
train_acc,▁▂▃▄▄▄▄▅▅▅▅▅▆▆▆▆▆▆▆▆▆▆▇▇▇▇▇▇▇▇▇▇▇███████
train_loss,█▇▆▆▅▅▅▄▄▄▄▄▄▃▃▃▃▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▁▂▁▁▁▁▁
val_acc,▁▂▃▄▄▅▅▅▆▆▆▆▆▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇█▇██▇███▇
val_loss,█▆▅▄▄▃▄▃▃▃▂▂▂▁▂▂▂▁▁▁▁▂▁▁▂▂▁▁▂▁▁▁▃▁▁▁▁▁▂▂

0,1
epoch,96.0
train_acc,0.37039
train_loss,1.89854
val_acc,0.34236
val_loss,2.17764
