In [1]:
import os
# Environment switch used by the later cells: True only when the
# 'KAGGLE_URL_BASE' variable is present in the process environment.
IN_KAGGLE = os.environ.get('KAGGLE_URL_BASE') is not None
IN_KAGGLE

False

In [2]:
if IN_KAGGLE:
    from kaggle_secrets import UserSecretsClient
    user_secrets = UserSecretsClient()
    secret_value = user_secrets.get_secret("gittoken")

    !git clone https://{secret_value}@github.com/moienr/TemporalGAN.git

In [3]:
if IN_KAGGLE:
    # Wait until the cloned repository is visible on disk, then make its
    # packages importable by adding it to sys.path.
    import time
    import os
    import sys

    repo_path = "/kaggle/working/TemporalGAN"
    sleep_time = 5    # seconds between two existence checks
    max_wait = 300    # stop polling after this many seconds instead of hanging forever
    waited = 0
    while not os.path.exists(repo_path):
        if waited >= max_wait:
            raise FileNotFoundError(
                f"{repo_path} did not appear after {max_wait} seconds; "
                "check that the git clone in the previous cell succeeded."
            )
        # f-string so the actual delay is printed rather than the literal "{sleep_time}".
        print(f"didn't find the path, waiting {sleep_time} more seconds...")
        time.sleep(sleep_time)
        waited += sleep_time
    print("path found...")
    # Avoid stacking duplicate entries when the cell is re-run.
    if repo_path not in sys.path:
        sys.path.append(repo_path)

In [4]:
import torch
import numpy as np
from torch.utils.data import Dataset
from glob import glob
from skimage import io
import os
from torchvision import datasets, transforms
import matplotlib
import os
# Display the number of logical CPUs reported for this machine.
os.cpu_count()

  from .autonotebook import tqdm as notebook_tqdm


20

In [5]:
# Compute device used for the models and batches below: the GPU when CUDA is
# available, otherwise the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
DEVICE

'cuda'

In [6]:
from dataset.data_loaders import *
from dataset.utils.plot_utils import plot_s1s2_tensors, save_s1s2_tensors_plot
#from config import *
from train_utils import *

In [7]:
from temporalgan.temporal_gan_v3_gen import Generator as GeneratorV3
from temporalgan.temporal_gan_v2_gen import Generator as GeneratorV2
from temporalgan.temporal_gan_v1_gen import Generator as GeneratorV1
from temporalgan.temporal_gan_v2_disc import Discriminator as DiscriminatorV2
from temporalgan.temporal_gan_v1_disc import Discriminator as DiscriminatorV1
from temporalgan.lossfunciton.loss_function import WeightedL1Loss

In [8]:
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader

In [9]:
# Image folders for the two acquisition dates: t1 -> 2021, t2 -> 2019.
# Each date has a Sentinel-1 ("s1_imgs") and a Sentinel-2 ("s2_imgs") folder.
if IN_KAGGLE:
    _t1_root = "/kaggle/input/s1s2-2021-v2/2021"
    _t2_root = "/kaggle/input/s1s2-2019-v2/2019"
    _train_split, _test_split = "train", "test"
    _path_sep = "/"
else:
    _t1_root = "E:\\s1s2\\s1s2_patched_light\\s1s2_patched_extra_light\\2021"
    _t2_root = "E:\\s1s2\\s1s2_patched_light\\s1s2_patched_extra_light\\2019"
    # NOTE(review): locally the *train* directories also point at the "test"
    # folders -- presumably on purpose for the light local run; confirm.
    _train_split = _test_split = "test"
    _path_sep = "\\"

s1_t1_dir_train = _path_sep.join((_t1_root, "s1_imgs", _train_split))
s2_t1_dir_train = _path_sep.join((_t1_root, "s2_imgs", _train_split))
s1_t2_dir_train = _path_sep.join((_t2_root, "s1_imgs", _train_split))
s2_t2_dir_train = _path_sep.join((_t2_root, "s2_imgs", _train_split))
s1_t1_dir_test = _path_sep.join((_t1_root, "s1_imgs", _test_split))
s2_t1_dir_test = _path_sep.join((_t1_root, "s2_imgs", _test_split))
s1_t2_dir_test = _path_sep.join((_t2_root, "s1_imgs", _test_split))
s2_t2_dir_test = _path_sep.join((_t2_root, "s2_imgs", _test_split))

In [10]:
# --- Dataset / model input layout --------------------------------------------
# Both switches are enabled only for the Kaggle run.
TWO_WAY_DATASET = bool(IN_KAGGLE)    # forwarded to Sen12Dataset(two_way=...)
INPUT_CHANGE_MAP = bool(IN_KAGGLE)   # forwarded as cm_input=...; eval_fn concatenates change maps when set
# Channel counts handed to the generator / discriminator constructors.
S2_INCHANNELS = 12 if INPUT_CHANGE_MAP else 6
S1_INCHANNELS = 7 if INPUT_CHANGE_MAP else 1

# --- Optimisation ------------------------------------------------------------
LEARNING_RATE = 2e-4
BATCH_SIZE = 1
NUM_WORKERS = 1 if IN_KAGGLE else 8
IMAGE_SIZE = 256          # not referenced in the visible cells
WEIGHTED_LOSS = True      # selects WeightedL1Loss over nn.L1Loss in main()
L1_LAMBDA = 100           # not referenced in the visible cells
NUM_EPOCHS = 10 if IN_KAGGLE else 2

# --- Checkpoints and example plots ---------------------------------------------
LOAD_MODEL = False
SAVE_MODEL = bool(IN_KAGGLE)
SAVE_EXAMPLE_PLOTS = True
# Dataset indices passed to save_some_examples after every epoch.
EXAMPLES_TO_PLOT = [1,32,64,128,256,512,1024] if IN_KAGGLE else [1,2]
SAVE_MODEL_EVERY_EPOCH = 10   # checkpoints are written when epoch % this == 0
CHECKPOINT_DISC = "disc.pth.tar"
CHECKPOINT_GEN = "gen.pth.tar"



In [11]:
import gc

In [12]:
def main():
    """Build the GAN, train it for NUM_EPOCHS epochs and return the generator.

    Everything is driven by the notebook-level configuration constants
    (channel counts, learning rate, loss choice, checkpoint and plotting
    switches) and by the ``*_dir_train`` paths. After each epoch the function
    optionally writes checkpoints and example plots, then runs the garbage
    collector and empties the CUDA cache.

    Returns:
        The generator instance after the last epoch.
    """
    adam_betas = (0.5, 0.999)

    # Models and their optimisers (discriminator is built first, then the generator).
    disc = DiscriminatorV1(
        s2_in_channels=S2_INCHANNELS,
        s1_in_channels=S1_INCHANNELS,
    ).to(DEVICE)
    gen = GeneratorV2(
        s2_in_channels=S2_INCHANNELS,
        s1_in_channels=S1_INCHANNELS,
        features=64,
        pam_downsample=2,
    ).to(DEVICE)
    opt_disc = optim.Adam(disc.parameters(), lr=LEARNING_RATE, betas=adam_betas)
    opt_gen = optim.Adam(gen.parameters(), lr=LEARNING_RATE, betas=adam_betas)

    # Losses: adversarial BCE plus either the weighted or the plain L1 term.
    bce_loss = nn.BCEWithLogitsLoss()
    l1_loss = WeightedL1Loss(change_weight=5) if WEIGHTED_LOSS else nn.L1Loss()

    # Optionally resume both networks from their checkpoint files.
    if LOAD_MODEL:
        resume_targets = (
            (CHECKPOINT_GEN, gen, opt_gen),
            (CHECKPOINT_DISC, disc, opt_disc),
        )
        for ckpt_file, model, optimizer in resume_targets:
            load_checkpoint(ckpt_file, model, optimizer, LEARNING_RATE)

    # Training data pipeline.
    transform = transforms.Compose([
        S2S1Normalize(),
        myToTensor(dtype=torch.float32),
    ])
    train_dataset = Sen12Dataset(
        s1_t1_dir=s1_t1_dir_train,
        s2_t1_dir=s2_t1_dir_train,
        s1_t2_dir=s1_t2_dir_train,
        s2_t2_dir=s2_t2_dir_train,
        transform=transform,
        two_way=TWO_WAY_DATASET,
    )
    train_loader = DataLoader(
        train_dataset,
        batch_size=BATCH_SIZE,
        shuffle=True,
        num_workers=NUM_WORKERS,
    )

    # One gradient scaler per network; both are handed to the training step.
    g_scaler = torch.cuda.amp.GradScaler()
    d_scaler = torch.cuda.amp.GradScaler()
    # No validation dataset / loader is created in this function.

    last_epoch = NUM_EPOCHS
    for epoch in range(1, last_epoch + 1):
        print("\n\n", end="")
        print(f"Epoch: {epoch}")
        print("")

        train_fn_no_tqdm(
            disc,
            gen,
            train_loader,
            opt_disc,
            opt_gen,
            l1_loss,
            bce_loss,
            g_scaler,
            d_scaler,
            weighted_loss=WEIGHTED_LOSS,
            cm_input=INPUT_CHANGE_MAP,
        )

        # Periodic checkpoints (the modulo is only evaluated when saving is enabled).
        if SAVE_MODEL and epoch % SAVE_MODEL_EVERY_EPOCH == 0:
            save_checkpoint(epoch, gen, opt_gen, filename=CHECKPOINT_GEN)
            save_checkpoint(epoch, disc, opt_disc, filename=CHECKPOINT_DISC)

        # Example plots for the configured dataset indices.
        if SAVE_EXAMPLE_PLOTS:
            for example_index in EXAMPLES_TO_PLOT:
                save_some_examples(
                    gen,
                    train_dataset,
                    epoch,
                    folder="evaluation_plots",
                    cm_input=INPUT_CHANGE_MAP,
                    img_indx=example_index,
                )

        # Release Python garbage and cached CUDA memory before the next epoch.
        gc.collect()
        torch.cuda.empty_cache()

    return gen


In [13]:
# Select matplotlib's 'Agg' backend so that the plots produced during training
# are not displayed below this cell.
matplotlib.use('Agg')
# Run the training loop and keep the returned generator for evaluation.
gen_model = main()



Epoch: 1

Batch Number: 4 of 4

Epoch: 2

Batch Number: 4 of 4

In [14]:
# Test-time pipeline: same normalisation as in training.
# NOTE(review): myToTensor is called without dtype=torch.float32 here, unlike
# in main(); eval_fn casts every tensor to float32 itself.
transform = transforms.Compose([
    S2S1Normalize(),
    myToTensor(),
])

test_dataset = Sen12Dataset(
    s1_t1_dir=s1_t1_dir_test,
    s2_t1_dir=s2_t1_dir_test,
    s1_t2_dir=s1_t2_dir_test,
    s2_t2_dir=s2_t2_dir_test,
    transform=transform,
    two_way=TWO_WAY_DATASET,
)

# Fixed sample order (no shuffling) for evaluation.
test_loader = DataLoader(
    test_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=NUM_WORKERS,
)

In [15]:
# save_some_examples(gen_model, test_dataset, NUM_EPOCHS, folder="evaluation",cm_input=INPUT_CHANGE_MAP, img_indx=2)

In [16]:
def eval_fn(gen, loader, ssim, psnr):
    """Evaluate ``gen`` on ``loader`` and report mean SSIM / PSNR scores.

    For every batch the generator predicts from ``(s2t2, s1t1)``; when
    INPUT_CHANGE_MAP is set, ``cm`` and ``rcm`` are first concatenated to those
    inputs along ``dim=1``. The prediction is compared with ``s1t2`` by both
    metrics, once weighted by ``s1cm`` and once unweighted. The running means
    are shown in the tqdm postfix after every batch.

    Args:
        gen: generator module; it is switched to eval mode and left in it.
        loader: iterable yielding 7-tuples
            ``(s2t2, s1t2, s2t1, s1t1, cm, rcm, s1cm)`` of tensors.
        ssim: callable invoked as ``ssim((target, pred), weight)`` and
            ``ssim((target, pred))``; each call must yield a single scalar.
        psnr: callable invoked with the same two call forms as ``ssim``.

    Returns:
        dict mapping ``"wssim_mean"``, ``"ssim_mean"``, ``"wpsnr_mean"`` and
        ``"psnr_mean"`` to the mean of the corresponding per-batch score as a
        Python float (``nan`` for every key when the loader yields no batch).
    """
    gen.eval()
    loop = tqdm(loader, leave=True)

    metric_names = ("wssim_mean", "ssim_mean", "wpsnr_mean", "psnr_mean")
    totals = dict.fromkeys(metric_names, 0.0)
    num_batches = 0

    # Inference only: without no_grad each forward pass records an autograd
    # graph, which costs memory and would be kept alive by the stored scores.
    with torch.no_grad():
        for batch in loop:
            # Move to the compute device and cast to float32 in a single step.
            s2t2, s1t2, s2t1, s1t1, cm, rcm, s1cm = (
                tensor.to(DEVICE, dtype=torch.float32) for tensor in batch
            )
            if INPUT_CHANGE_MAP:
                s2t2 = torch.cat((s2t2, cm), dim=1)
                s1t1 = torch.cat((s1t1, rcm), dim=1)
            preds = gen(s2t2, s1t1)

            target_and_pred = (s1t2, preds)
            batch_scores = {
                "wssim_mean": ssim(target_and_pred, s1cm),
                "ssim_mean": ssim(target_and_pred),
                "wpsnr_mean": psnr(target_and_pred, s1cm),
                "psnr_mean": psnr(target_and_pred),
            }

            # Accumulate plain floats so the running means cost O(1) per batch.
            num_batches += 1
            for name, score in batch_scores.items():
                totals[name] += float(score)

            loop.set_postfix(
                **{name: totals[name] / num_batches for name in metric_names}
            )

    if num_batches == 0:
        return dict.fromkeys(metric_names, float("nan"))
    return {name: totals[name] / num_batches for name in metric_names}
        

In [17]:
from eval_metrics import ssim
from eval_metrics.psnr import wpsnr
# SSIM metric object handed to eval_fn as its `ssim` callable.
# NOTE(review): data_range=1.0 presumably matches images scaled to [0, 1] -- confirm.
wssim = ssim.WSSIM(data_range=1.0)

In [18]:
# Score the trained generator on the test loader with the SSIM object and the
# wpsnr function imported above.
eval_fn(gen_model, test_loader, wssim, wpsnr)

100%|██████████| 4/4 [00:03<00:00,  1.17it/s, psnr_mean=6.11, ssim_mean=0.0397, wpsnr_mean=5.81, wssim_mean=0.0397]
