In [1]:
import torch
import numpy as np
from torch.utils.data import Dataset
from glob import glob
from skimage import io
import os
from torchvision import datasets, transforms
import matplotlib
import os
# Display the number of logical CPUs on this machine (None if it cannot be
# determined); presumably checked to choose NUM_WORKERS for the DataLoader — TODO confirm.
os.cpu_count()

  from .autonotebook import tqdm as notebook_tqdm


20

In [2]:
# Prefer the GPU whenever PyTorch can see one; otherwise run on the CPU.
# NOTE(review): `from config import *` in a later cell may rebind DEVICE if
# config defines it — confirm which value main() actually uses.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
DEVICE

'cuda'

In [3]:
from dataset.data_loaders import *
from dataset.utils.plot_utils import plot_s1s2_tensors, save_s1s2_tensors_plot
from config import *
from train_utils import *

In [5]:
from temporalgan.temporal_gan_v3_gen import Generator as GeneratorV3
from temporalgan.temporal_gan_v2_gen import Generator as GeneratorV2
from temporalgan.temporal_gan_v1_gen import Generator as GeneratorV1
from temporalgan.temporal_gan_v2_disc import Discriminator as DiscriminatorV2
from temporalgan.temporal_gan_v1_disc import Discriminator as DiscriminatorV1
from temporalgan.lossfunciton.loss_function import WeightedL1Loss

In [12]:
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader

In [6]:
def main(
    data_root="E:\\s1s2\\s1s2_patched_light\\s1s2_patched_extra_light",
    split="test",
    t1_year="2021",
    t2_year="2019",
):
    """Train the temporal GAN and return the trained generator.

    Builds ``DiscriminatorV1`` and ``GeneratorV2`` on ``DEVICE``, Adam
    optimizers with ``betas=(0.5, 0.999)``, a ``BCEWithLogitsLoss`` and an L1
    loss (``WeightedL1Loss(change_weight=5)`` when ``WEIGHTED_LOSS`` is set,
    otherwise ``nn.L1Loss``). When ``LOAD_MODEL`` is set, both models and
    optimizers are restored from ``CHECKPOINT_GEN`` / ``CHECKPOINT_DISC``.
    It then runs ``NUM_EPOCHS`` epochs of ``train_fn``.

    After every epoch, ``save_some_examples`` is called with
    ``folder="evaluation"``. Checkpoints are saved only when ``SAVE_MODEL`` is
    set, every ``SAVE_EVERY_EPOCH`` epochs.

    Args:
        data_root: Folder containing one sub-folder per year, each holding
            ``s1_imgs/<split>`` and ``s2_imgs/<split>``. The default is a
            machine-specific absolute Windows path; pass another root on
            other machines.
        split: Split sub-folder to train on. NOTE(review): defaults to
            ``"test"`` to match the original notebook. Those are the same
            folders this notebook later evaluates on, so the evaluation is not
            on held-out data. Confirm whether a dedicated training split exists.
        t1_year: Year folder used for the ``s1_t1_dir`` / ``s2_t1_dir`` inputs.
        t2_year: Year folder used for the ``s1_t2_dir`` / ``s2_t2_dir`` inputs.

    Returns:
        The generator after training. It is also returned when
        ``NUM_EPOCHS < 1``, in which case it is untrained.
    """
    disc = DiscriminatorV1(s2_in_channels=S2_INCHANNELS, s1_in_channels=S1_INCHANNELS).to(DEVICE)
    gen = GeneratorV2(
        s2_in_channels=S2_INCHANNELS,
        s1_in_channels=S1_INCHANNELS,
        features=64,
        pam_downsample=2,
    ).to(DEVICE)
    opt_disc = optim.Adam(disc.parameters(), lr=LEARNING_RATE, betas=(0.5, 0.999))
    opt_gen = optim.Adam(gen.parameters(), lr=LEARNING_RATE, betas=(0.5, 0.999))
    BCE = nn.BCEWithLogitsLoss()
    # change_weight=5 presumably up-weights changed pixels — TODO confirm in WeightedL1Loss.
    L1_LOSS = WeightedL1Loss(change_weight=5) if WEIGHTED_LOSS else nn.L1Loss()

    if LOAD_MODEL:
        load_checkpoint(CHECKPOINT_GEN, gen, opt_gen, LEARNING_RATE)
        load_checkpoint(CHECKPOINT_DISC, disc, opt_disc, LEARNING_RATE)

    transform = transforms.Compose([S2S1Normalize(), myToTensor()])
    train_dataset = Sen12Dataset(
        s1_t1_dir=os.path.join(data_root, t1_year, "s1_imgs", split),
        s2_t1_dir=os.path.join(data_root, t1_year, "s2_imgs", split),
        s1_t2_dir=os.path.join(data_root, t2_year, "s1_imgs", split),
        s2_t2_dir=os.path.join(data_root, t2_year, "s2_imgs", split),
        transform=transform,
        two_way=False,
    )
    train_loader = DataLoader(
        train_dataset,
        batch_size=BATCH_SIZE,
        shuffle=True,
        num_workers=NUM_WORKERS,
    )
    # Separate GradScalers for the generator and discriminator updates in train_fn.
    g_scaler = torch.cuda.amp.GradScaler()
    d_scaler = torch.cuda.amp.GradScaler()

    for epoch in range(1, NUM_EPOCHS + 1):
        print(f"Epoch: {epoch}")
        train_fn(
            disc, gen, train_loader, opt_disc, opt_gen, L1_LOSS, BCE, g_scaler, d_scaler,
        )

        # Persist checkpoints only when enabled in the config, every SAVE_EVERY_EPOCH epochs.
        if SAVE_MODEL and epoch % SAVE_EVERY_EPOCH == 0:
            save_checkpoint(epoch, gen, opt_gen, filename=CHECKPOINT_GEN)
            save_checkpoint(epoch, disc, opt_disc, filename=CHECKPOINT_DISC)

        # Example predictions are dumped every epoch for visual monitoring.
        save_some_examples(
            gen, train_dataset, epoch, folder="evaluation", cm_input=INPUT_CHANGE_MAP, img_indx=1,
        )

    return gen



In [7]:
# Switch matplotlib to the non-interactive "Agg" backend so plotted results
# are not displayed below this cell while training runs.
matplotlib.use('Agg')
gen_model = main()

Epoch: 1


100%|██████████| 4/4 [00:09<00:00,  2.25s/it, D_fake=0.283, D_real=0.546, G_loss=408, L1=407]


In [10]:
# Evaluation dataset. NOTE(review): these are the same "test" folders that
# main() trains on by default, so this is not held-out data. Point EVAL_SPLIT
# at a separate split if one exists.
S1S2_DATA_ROOT = "E:\\s1s2\\s1s2_patched_light\\s1s2_patched_extra_light"  # machine-specific path
EVAL_SPLIT = "test"

transform = transforms.Compose([S2S1Normalize(), myToTensor()])
test_dataset = Sen12Dataset(
    s1_t1_dir=os.path.join(S1S2_DATA_ROOT, "2021", "s1_imgs", EVAL_SPLIT),
    s2_t1_dir=os.path.join(S1S2_DATA_ROOT, "2021", "s2_imgs", EVAL_SPLIT),
    s1_t2_dir=os.path.join(S1S2_DATA_ROOT, "2019", "s1_imgs", EVAL_SPLIT),
    s2_t2_dir=os.path.join(S1S2_DATA_ROOT, "2019", "s2_imgs", EVAL_SPLIT),
    transform=transform,
    two_way=False,
)

In [11]:
# Save example predictions of the trained generator on test_dataset (sample
# index 2) with folder="evaluation", passing NUM_EPOCHS as the epoch label.
save_some_examples(
    gen_model,
    test_dataset,
    NUM_EPOCHS,
    folder="evaluation",
    cm_input=INPUT_CHANGE_MAP,
    img_indx=2,
)