In [None]:
import os

import numpy as np
import torch

from landcoverDataset import LandCoverDataset

In [None]:
# Define dataloader and other config.

# Seed every RNG the training cells touch so weight initialisation (and any
# stochastic layers) is reproducible under Restart Kernel -> Run All.
# torch.manual_seed seeds the CPU and all CUDA devices.
SEED = 0
torch.manual_seed(SEED)
np.random.seed(SEED)

DATA_PATH = r'dataSourcev2/original.npy'
BATCH_SIZE = 4
NUM_WORKERS = 4

dataset = LandCoverDataset(DATA_PATH)
# NOTE(review): every training cell below iterates this loader; shuffle=False means
# batches arrive in the same dataset order each epoch — confirm this is intended.
loader = torch.utils.data.DataLoader(dataset, num_workers=NUM_WORKERS, batch_size=BATCH_SIZE, shuffle=False)

n_class = 7  # passed as UNetResNet's first argument; presumably the number of segmentation classes
n_ch = 9     # passed as the input-channel argument; presumably the number of image bands

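## Train autoencoder

Trains the autoencoder on its own with an MSE reconstruction loss and saves its weights to `trained/ae-resnet.pth`.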

In [None]:
from model_basic_resnet import VanillaVAE

# Autoencoder-only training: reconstruct each input image under an MSE loss.
AE_EPOCHS = 50
AE_LR = 0.001
AE_CKPT = 'trained/ae-resnet.pth'

model = VanillaVAE(n_ch, [8, 4, 2, 1]).cuda().train()
loss_function = torch.nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=AE_LR)

for epoch_idx in range(AE_EPOCHS):
    loss_batches = []
    for img, _mask in loader:  # masks are not used when training the autoencoder alone
        img = img.cuda()

        y = model(img)
        loss = loss_function(y, img)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # .item() returns a detached Python float (replaces legacy loss.data.cpu().numpy()).
        loss_batches.append(loss.item())

    # Mean over batches so the number is comparable across dataset / batch sizes.
    print(f'epoch: {epoch_idx} mean training loss: {np.mean(loss_batches):.6f}')

# Create the output directory first: a missing dir would otherwise crash only after training.
os.makedirs(os.path.dirname(AE_CKPT), exist_ok=True)
# Note: model.cpu() moves `model` to the CPU in place.
torch.save(model.cpu().state_dict(), AE_CKPT)
print('model saved')

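## Train autoencoder + segmentation loss

Fine-tunes a pretrained autoencoder (loaded from `trained/ae/1.pth`) with its MSE reconstruction loss plus a 0.01-weighted cross-entropy loss. The cross-entropy comes from a frozen segmentation network applied to the reconstruction. This section loads `trained/seg-resnet.pth`, which the *Train segmentation* section further down produces, so on a fresh kernel run that section first. Note that the autoencoder section above saves to `trained/ae-resnet.pth`, not `trained/ae/1.pth`.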

In [None]:
from model_basic_resnet import VanillaVAE
from model_seg import UNetResNet

# Fine-tune a pretrained autoencoder so its reconstructions remain segmentable:
#   loss = MSE(recon, img) + SEG_LOSS_WEIGHT * CE(frozen_seg(recon), mask)
# NOTE(review): the autoencoder cell saves to 'trained/ae-resnet.pth'; confirm this init checkpoint is intended.
AE_INIT_CKPT = 'trained/ae/1.pth'
SEG_CKPT = 'trained/seg-resnet.pth'  # produced by the segmentation-training cell
AE_SEG_CKPT = 'trained/ae-seg/1.pth'
AE_SEG_EPOCHS = 100
AE_SEG_LR = 0.005
SEG_LOSS_WEIGHT = 0.01
LR_STEP_EPOCHS = 33
LR_GAMMA = 0.5

module = VanillaVAE(n_ch, [8, 4, 2, 1]).cuda()
module.load_state_dict(torch.load(AE_INIT_CKPT))
module = module.train()

# The segmentation net is a fixed loss network. Freezing its weights is not enough:
# it must also be in eval mode so BatchNorm running statistics (if any) are not
# overwritten by reconstructions and dropout (if any) does not add noise to the loss.
segmodule = UNetResNet(n_class, n_ch).cuda()
segmodule.load_state_dict(torch.load(SEG_CKPT))
segmodule.requires_grad_(False)
segmodule.eval()

lossFun1 = torch.nn.MSELoss()
lossFun2 = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(module.parameters(), lr=AE_SEG_LR)
# StepLR: multiply the learning rate by LR_GAMMA every LR_STEP_EPOCHS epochs.
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, LR_STEP_EPOCHS, LR_GAMMA)

for epoch_idx in range(AE_SEG_EPOCHS):
    loss_batches = []
    for img, mask in loader:
        img = img.cuda()
        mask = mask.long().cuda()

        y = module(img)
        loss1 = lossFun1(y, img)
        # segmodule's own weights are frozen, but gradients still flow through it into `module` via y.
        segOutput = segmodule(y)
        loss2 = lossFun2(segOutput, mask)
        loss = loss1 + SEG_LOSS_WEIGHT * loss2

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        loss_batches.append(loss.item())
    scheduler.step()

    print(f'epoch: {epoch_idx} mean training loss: {np.mean(loss_batches):.6f}')

# Create the output directory first: a missing dir would otherwise crash only after training.
os.makedirs(os.path.dirname(AE_SEG_CKPT), exist_ok=True)
torch.save(module.cpu().state_dict(), AE_SEG_CKPT)
print('model saved')

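## Train segmentation

Trains the UNet-ResNet segmentation network with a cross-entropy loss and saves it to `trained/seg-resnet.pth`. This is the checkpoint that the autoencoder + segmentation section above loads.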

In [None]:
from model_seg import UNetResNet

# Train the segmentation network on (image, mask) pairs with cross-entropy.
SEG_EPOCHS = 100
SEG_LR = 0.001
SEG_CKPT = 'trained/seg-resnet.pth'

model = UNetResNet(n_class, n_ch).cuda().train()
# The loop below calls `loss_function`, so the criterion must be bound to that name.
# Binding it to `loss` instead would let the per-batch loss tensor shadow it and make the
# loop pick up a stale criterion from another cell (or raise NameError on a fresh kernel).
loss_function = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=SEG_LR)

for epoch_idx in range(SEG_EPOCHS):
    loss_batches = []
    for imgs, masks in loader:
        imgs = imgs.cuda()
        masks = masks.long().cuda()  # CrossEntropyLoss expects integer class targets

        y = model(imgs)
        loss = loss_function(y, masks)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        loss_batches.append(loss.item())

    print(f'epoch: {epoch_idx} mean training loss: {np.mean(loss_batches):.6f}')

# Create the output directory first: a missing dir would otherwise crash only after training.
os.makedirs(os.path.dirname(SEG_CKPT), exist_ok=True)
torch.save(model.cpu().state_dict(), SEG_CKPT)
print('model saved')