In [1]:
import random

import numpy as np
import torch
from torch.nn import BCELoss
from torch.optim import SGD
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from torchgeo.datasets import Sentinel2, RasterDataset, stack_samples
from torchgeo.samplers import RandomGeoSampler
from tqdm import tqdm

import dataset
from unet_model import UNet

In [2]:
# def experiment(test_site):
#     training_sites = ['herschel', 'horton', 'kolguev', 'lena', 'unvalidated']
#     training_sites.remove(test_site)
    
#     training_data = dataset.Permafrost(training_sites)
#     test_data = dataset.Permafrost([test_site])
#     return test_data, training_data

# test_data, training_data = experiment('herschel')

In [18]:
# Hyper-parameters and data pipeline: random patches sampled from the
# 'unvalidated' site for training and from the 'lena' site for validation.
num_epochs = 100
learning_rate = 0.01
SEED = 42        # fixes weight init and patch sampling so runs are comparable
PATCH_SIZE = 32  # passed to RandomGeoSampler as `size` for both splits

# Seed every RNG that may drive this run: Python's `random`, NumPy and torch
# (torch also covers the UNet weight initialisation below, so seed first).
# NOTE(review): confirm which RNG(s) the installed torchgeo RandomGeoSampler
# draws from; seeding all three is the safe choice.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# DataLoader uses its default batch_size (1) here.
train = RasterDataset(root='../all_data/unvalidated')
sampler = RandomGeoSampler(train, size=PATCH_SIZE, length=100)
train_dataloader = DataLoader(train, num_workers=4, sampler=sampler, collate_fn=stack_samples)

# `sampler` is rebound below; train_dataloader keeps its own reference to the
# training sampler, so this does not affect training.
val = RasterDataset(root='../all_data/lena')
sampler = RandomGeoSampler(val, size=PATCH_SIZE, length=10)
val_dataloader = DataLoader(val, num_workers=4, sampler=sampler, collate_fn=stack_samples)

# n_channels counts only the input bands: the training loop slices band 0 off
# as the label and feeds the rest to the model.
# NOTE(review): this implies the rasters have 8 bands — confirm.
model = UNet(n_channels=7, n_classes=1)
optimizer = SGD(model.parameters(), lr=learning_rate)

In [None]:
# Train the UNet with SGD. The per-batch training loss and the per-epoch mean
# validation loss are logged to TensorBoard. Mean per-batch train/validation
# losses are printed once per epoch.
writer = SummaryWriter()
loss = BCELoss()  # a single criterion instance, shared by training and validation
global_step = 0   # TensorBoard x-axis for "Loss/train": one tick per training batch
try:
    for e in range(num_epochs):
        print("Epoch ", e)

        # --- training ---
        model.train()  # undo the model.eval() set by the previous epoch's validation
        train_loss = 0.0
        for sample in tqdm(train_dataloader):
            data = sample["image"]
            # Band 0 holds the label mask; the remaining bands are the model input.
            labels = data[:, 0, :, :]
            img = data[:, 1:, :, :]

            pred = model(img.float())
            # Reshape to the labels' own shape instead of a hard-coded
            # (1, 32, 32), which only worked for batch_size=1 and 32x32 patches.
            # NOTE(review): BCELoss expects probabilities in [0, 1]; confirm the
            # UNet ends in a sigmoid, otherwise switch to BCEWithLogitsLoss.
            output = loss(pred.reshape(labels.shape), labels.float())

            optimizer.zero_grad()
            output.backward()  # without this, optimizer.step() has no gradients to apply
            optimizer.step()

            batch_loss = output.item()
            writer.add_scalar("Loss/train", batch_loss, global_step)
            global_step += 1
            train_loss += batch_loss

        # --- validation ---
        valid_loss = 0.0
        model.eval()  # Optional when not using Model Specific layer
        with torch.no_grad():  # evaluation only: skip building the autograd graph
            for val_sample in val_dataloader:
                val_data = val_sample["image"]
                val_labels = val_data[:, 0, :, :]
                val_img = val_data[:, 1:, :, :]

                val_pred = model(val_img.float())
                val_output = loss(val_pred.reshape(val_labels.shape), val_labels.float())
                # Accumulate the mean per-batch loss (the original overwrote it
                # each batch), so the epoch figure is comparable to train_loss.
                valid_loss += val_output.item()

        writer.add_scalar("Loss/val", valid_loss / len(val_dataloader), e)
        print(f'Epoch {e+1} \t\t Training Loss: {train_loss / len(train_dataloader)} \t\t Validation Loss: {valid_loss / len(val_dataloader)}')
finally:
    # Flush pending events even if the run is interrupted (e.g. kernel interrupt).
    writer.close()

Epoch  0


 72%|███████████████████████████████████████████████████████████                       | 72/100 [00:05<00:02, 11.43it/s]

In [None]:
# NOTE(review): this rebinds `model`, replacing the UNet built and trained in
# the cells above; use a distinct name (e.g. `resnet_model`) if the UNet is
# still needed afterwards.
# NOTE(review): this import belongs in the top import cell.
# NOTE(review): the zero-argument call relies on the installed torchgeo
# version's `resnet50` signature — confirm that no sensor/bands/weights
# arguments are required.
from torchgeo.models import resnet50

model = resnet50()