In [1]:
import os, glob, json, torch, timm

import numpy as np
import pandas as pd
import rasterio as rio

ModuleNotFoundError: No module named 'torch._C'

In [2]:
from torch.utils.data import Dataset, DataLoader


In [3]:
class EuroSATDataset(Dataset):
    def __init__(self, mode, root_dir):
        vec_file = f"{root_dir}/vectors/{mode}.csv"
        meta_file = f"{root_dir}/vectors/metadata.json"
        with open(meta_file) as out:
            task_meta = json.load(out)

        classes = [lbl_meta["options"] for lbl_meta in task_meta["label:metadata"]][0]
        cls_idx_map = {cls: idx for idx, cls in enumerate(classes)}

        vec_df = pd.read_csv(vec_file)
        vec_df["image"] = vec_df["image:01"].apply(lambda x: f'{root_dir}/rasters/{x.split("/")[-1]}')
        vec_df["label"] = vec_df["land-use-land-cover-class"].apply(lambda x: cls_idx_map[x])
        vec_df.drop(['image-id','image:01','date:01','type','geometry','land-use-land-cover-class'],axis=1,inplace=True)

        self.vec_df = vec_df
        self.classes = classes

    def __len__(self):
        return len(self.vec_df)

    def __getitem__(self, idx):
        df_entry = self.vec_df.loc[idx]
        smpl_map = {
            "image": rio.open(df_entry["image"]).read(),
            "label": df_entry["label"]
        }
        
        return smpl_map

In [4]:
root_dir = "/home/akash/Downloads/EuroSAT"
num_workers = 4
batch_size = 512

rstr_root_dir = f"{root_dir}/rasters"
vctr_root_dir = f"{root_dir}/vectors"

train_ds = EuroSATDataset(root_dir=root_dir, mode="train")
val_ds = EuroSATDataset(root_dir=root_dir, mode="val")

train_dl = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=num_workers)
val_dl = DataLoader(val_ds, batch_size=batch_size, shuffle=False, num_workers=num_workers)

In [5]:
from torchvision.models import resnet50
from torch.nn import CrossEntropyLoss
from torch.optim import SGD
from torch.optim.lr_scheduler import StepLR

In [6]:
device = torch.device("cuda")
classes = train_ds.classes
num_classes = len(classes)

lr = 0.01
momemtum = 0.9
weight_decay = 1e-4

In [None]:
model = resnet50(num_classes=num_classes)

model.to(device)

In [8]:
criterion = CrossEntropyLoss().to(device)
optimizer = SGD(model.parameters(),lr=lr,momentum=momemtum,weight_decay=weight_decay)
scheduler = StepLR(optimizer,step_size=30,gamma=0.1)

In [9]:
def update_losses(losses, loss):
    losses += [loss.item()]

def update_confusion_matrix(conf_matrix, outputs, labels):
    with torch.no_grad():
        maxk, batch_size = 1, labels.size(0)

        _, pred = outputs.topk(maxk,1,True, True)
        
        pred = pred.t().tolist()[0]
        truth = labels.tolist()

        for tidx, pidx in zip(truth,pred):
            conf_matrix[tidx,pidx] += 1
        

In [10]:
def print_loss_metrics(losses, conf_matrix, num_classes):
    avg_loss = sum(losses)/len(losses)

    recall = [conf_matrix[i,i]/sum(conf_matrix[:,i]) for i in range(num_classes)]
    precision = [conf_matrix[i,i]/sum(conf_matrix[i,:]) for i in range(num_classes)]
    f1_score = [(2*precision[i]*recall[i])/(precision[i]+recall[i]) for i in range(num_classes)]

    recall_map = {classes[idx]: recall[idx] for idx in range(num_classes)}
    precision_map = {classes[idx]: precision[idx] for idx in range(num_classes)}
    f1_score_map = {classes[idx]: f1_score[idx] for idx in range(num_classes)}

    metric_maps = {
        "precision": precision_map,
        "recall": recall_map,
        "f1 score": f1_score_map
    }

    avg_metric_map = {
        "precision": sum(precision)/num_classes,
        "recall": sum(recall)/num_classes,
        "f1 score": sum(f1_score)/num_classes,
        "accuracy": sum([conf_matrix[i,i] for i in range(num_classes)])/conf_matrix.sum()
    }

    print(f"\t\t loss: {avg_loss}")

    print("\t\t metrics:")
    
    print("\t\t\t confusion matrix:")
    print("\t\t\t "+f"{conf_matrix}".replace("\n","\n\t\t\t"))
    print("")

    print("\t\t\t class level metrics:")
    for met_key in metric_maps:
        print(f"\t\t\t\t {met_key}")
        for cls, metric in metric_maps[met_key].items():
            print(f"\t\t\t\t\t {cls}: {metric}")
    print("")

    print("\t\t\t overall metrics:")
    for met_key in avg_metric_map:
        print(f"\t\t\t\t {met_key}: {avg_metric_map[met_key]}")
    print("")

    return avg_metric_map["f1 score"]


In [11]:
max_val_f1 = -1
for epoch in range(1,5):
    print(f"epoch {epoch}")

    weights_path = "/home/akash/Downloads/EuroSAT/weights/resnet50_ckpt_{:03d}.pth".format(epoch)

    losses = []
    conf_matrix = np.zeros([num_classes,num_classes], dtype=np.uint32) 
    for idx, sample in enumerate(train_dl):
        inputs = sample["image"].float().to(device)
        labels = sample["label"].to(device)
        
        outputs = model(inputs)
        
        loss = criterion(outputs,labels)

        optimizer.zero_grad()

        loss.backward()
        optimizer.step()

        losses += [loss.item()]
        update_confusion_matrix(conf_matrix, outputs=outputs, labels=labels)

    print("\t Train : ")
    train_f1 = print_loss_metrics(losses=losses,conf_matrix=conf_matrix,num_classes=num_classes)

    losses = []
    conf_matrix = np.zeros([num_classes,num_classes], dtype=np.uint32)
    for idx, sample in enumerate(val_dl):
        inputs = sample["image"].float().to(device)
        labels = sample["label"].to(device)

        outputs = model(inputs)
        
        loss = criterion(outputs,labels)

        losses += [loss.item()]
        update_confusion_matrix(conf_matrix, outputs=outputs, labels=labels)
    
    print("\t Val : ")
    val_f1 = print_loss_metrics(losses=losses,conf_matrix=conf_matrix,num_classes=num_classes)
    
    if val_f1 > max_val_f1:
        max_val_f1 = val_f1
        torch.save(model.state_dict(), weights_path)




epoch 1
	 Train : 
		 loss: 2.01353091417357
		 metrics:
			 confusion matrix:
			 [[ 745   36  228   74  143   51  335  600  132   81]
			 [  64 1662   40    2   49   56    8  116  179  212]
			 [ 457  122  363  100   83  145  208  537  211  201]
			 [ 319  114  204  160  136   87  158  402  218  193]
			 [ 207    0   41   45 1409    0  119  150   28    3]
			 [  95  197  178   77   37  294   31  198  243  253]
			 [ 544   26  232   69  124   50  293  497   95   66]
			 [ 466   15  322  131  153   31  209  763  187  120]
			 [ 228  208  182  164   77  137   69  348  268  273]
			 [ 149  662  168   25   62  114   35  229  209  764]]

			 class level metrics:
				 precision
					 AnnualCrop: 0.30721649484536084
					 Forest: 0.6959798994974874
					 HerbaceousVegetation: 0.14956736711990112
					 Highway: 0.0803616273229533
					 Industrial: 0.7037962037962038
					 Pasture: 0.18340611353711792
					 PermanentCrop: 0.1467935871743487
					 Residential: 0.3183145598664998
					 River: 0.

In [16]:
conf_matrix.sum()

0

In [9]:
conf_matrix

NameError: name 'conf_matrix' is not defined