In [None]:
from pathlib import Path
import torch
import numpy as np
from PIL import Image
from landcoverDataset import LandCoverDataset
import matplotlib.pyplot as plt

In [None]:
# Dataset / loader configuration shared by every section below.
# n_class is passed to the segmentation model as num_classes and n_ch as in_channels;
# they stay plain module-level names because later cells read them.
n_class = 7
n_ch = 9
dataset = LandCoverDataset(r'dataSourcev2/pca-8.npy')
# batch_size=1 and shuffle=False make the enumerate() index in the inference cells
# equal to the dataset index, which is what the saved file names are based on.
loader = torch.utils.data.DataLoader(
    dataset,
    batch_size=1,
    shuffle=False,
    num_workers=0,
)

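## Segmentation

Run the trained segmentation model (`UNetResNet`) over the dataset and save one predicted class map per sample as a PNG.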

In [None]:
# Segmentation network from the project module `model_seg`, restored from its
# trained checkpoint and prepared for GPU inference.
from model_seg import UNetResNet

SEG_WEIGHTS = 'trained/seg-resnet.pth'
moduleSeg = UNetResNet(num_classes=n_class, in_channels=n_ch)
moduleSeg.load_state_dict(torch.load(SEG_WEIGHTS))
# .eval() switches layers such as dropout / batch-norm to inference behaviour.
moduleSeg = moduleSeg.cuda().eval()

# Softmax over the channel (class) axis at every spatial location.
softmax = torch.nn.Softmax2d()
# Destination directory for the predicted label maps.
savePath = Path(r'D:\pca\8-pca')

In [None]:
# Run the segmentation model over every sample and save the per-pixel class map
# as `<savePath>/<i>.png`, where i is the loader index (dataset order, batch_size=1).
# PIL does not create missing directories, so make sure the target exists first.
savePath.mkdir(parents=True, exist_ok=True)
with torch.no_grad():
    for i, (img, mask) in enumerate(loader):
        logits = moduleSeg(img.float().cuda())
        # Softmax is monotonic, so this argmax matches an argmax over the raw logits;
        # it is kept to stay identical to the original pipeline.
        pred = softmax(logits).argmax(1)
        # squeeze() drops the batch axis (assumes batch_size=1) -> 2-D label map.
        # .data is unnecessary here: no_grad already disables autograd tracking.
        pred = pred.cpu().numpy().squeeze().astype(np.uint8)
        # Save without rebinding `img`, so the input batch is not shadowed.
        Image.fromarray(pred).save(savePath / f'{i}.png')

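## Reconstruction

Load the trained autoencoder (`VanillaVAE`), segment its reconstructions with the segmentation model from the previous section, and save the reconstructed arrays together with their masks.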

In [None]:
# Autoencoder from the project module `model_basic_resnet`, restored from its
# trained checkpoint and prepared for GPU inference.
from model_basic_resnet import VanillaVAE

AE_WEIGHTS = 'trained/ae-seg/8.pth'
# NOTE(review): the first argument (9) equals n_ch; if it is the input channel
# count, prefer n_ch so the two stay in sync — confirm against VanillaVAE.
module = VanillaVAE(9, [64, 32, 16, 8])
module.load_state_dict(torch.load(AE_WEIGHTS))
module = module.cuda().eval()
# Output directory used by every cell of the reconstruction section.
savePath = Path(r'D:\pca\reconstructed\ae-seg-8')

In [None]:
# Reconstruct each sample with the autoencoder, segment the reconstruction with the
# segmentation model, and save the class map as `<savePath>/<i>.png`.
# Depends on `moduleSeg` and `softmax` from the segmentation section: run that
# section's model cell first, otherwise this cell raises NameError.
# PIL does not create missing directories, so make sure the target exists first.
savePath.mkdir(parents=True, exist_ok=True)
with torch.no_grad():
    for i, (img, mask) in enumerate(loader):
        recon = module(img.float().cuda())
        pred = softmax(moduleSeg(recon)).argmax(1)
        # squeeze() drops the batch axis (assumes batch_size=1) -> 2-D label map.
        pred = pred.cpu().numpy().squeeze().astype(np.uint8)
        Image.fromarray(pred).save(savePath / f'{i}.png')

In [None]:
# Reconstruct every sample with the autoencoder, concatenate each reconstruction with
# its mask along axis 0, and save the stacked result as `<savePath>/8.npy`.
# NOTE: savePath is whichever path was assigned last (the AE model cell in a linear run).
imgs = []
with torch.no_grad():
    for img, mask in loader:
        recon = module(img.float().cuda()).cpu().numpy().squeeze()
        # Convert the DataLoader's CPU mask tensor explicitly instead of relying on
        # np.concatenate's implicit __array__ conversion.
        # Assumes the batched mask is (1, H, W) so its batch axis lines up with the
        # reconstruction's channel axis when batch_size=1 — TODO confirm.
        # NOTE(review): mixing a float32 reconstruction with an integer mask promotes
        # the result to float64; cast explicitly if a smaller file is wanted.
        imgs.append(np.concatenate((recon, np.asarray(mask)), axis=0))
imgs = np.stack(imgs)
# np.save does not create missing directories.
savePath.mkdir(parents=True, exist_ok=True)
np.save(savePath / '8.npy', imgs)