In [1]:
import os
# Logical CPUs in the system (os.cpu_count() returns None if it cannot be determined).
n_logical_cpus = os.cpu_count()
n_logical_cpus

20

In [2]:
import torch
import numpy as np
from torch.utils.data import Dataset
from glob import glob
from skimage import io
import os
from torchvision import datasets, transforms

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
DEVICE = "cpu"
if torch.cuda.is_available():
    DEVICE = "cuda"
DEVICE

'cuda'

In [4]:
from dataset.data_loaders import *
from dataset.utils.plot_utils import plot_s1s2_tensors, save_s1s2_tensors_plot

In [5]:
# transform = transforms.Compose([S2S1Normalize(),myToTensor()])

# print("Reading only S1 2021 train data...")
# s1s2_dataset = Sen12Dataset(s1_t1_dir="E:\\s1s2\\s1s2_patched_light\\s1s2_patched_light\\2021\\s1_imgs\\train",
#                             s2_t1_dir="E:\\s1s2\\s1s2_patched_light\\s1s2_patched_light\\2021\\s2_imgs\\train",
#                             s1_t2_dir="E:\\s1s2\\s1s2_patched_light\\s1s2_patched_light\\2019\\s1_imgs\\train",
#                             s2_t2_dir="E:\\s1s2\\s1s2_patched_light\\s1s2_patched_light\\2019\\s2_imgs\\train",
#                             transform=transform,
#                             two_way=False)
# print("len(s1s2_dataset): ",len(s1s2_dataset))
# print("s1s2_dataset[0][0]shape: ",s1s2_dataset[0][1].shape)

In [6]:
# save_s1s2_tensors_plot(s1s2_dataset[1], ["s2t2", "s1t2", "s2t1", "s1t1", "change map", "reversed change map"], 3,2,filename="test.png", fig_size=(8,10))

In [7]:
# --- Network input channels (passed to the generator/discriminator constructors) ---
S2_INCHANNELS, S1_INCHANNELS = 6, 1

# --- Optimisation / data loading ---
LEARNING_RATE = 2e-4          # Adam lr for both networks
BATCH_SIZE = 1
NUM_WORKERS = 8               # DataLoader worker processes
IMAGE_SIZE = 256              # NOTE(review): not referenced by the cells shown here

# --- Loss configuration ---
WEIGHTED_LOSS = True          # True -> WeightedL1Loss(fake, real, cm, rcm); False -> nn.L1Loss
INPUT_CHANGE_MAP = False      # True -> concatenate cm / rcm onto the network inputs
L1_LAMBDA = 100               # multiplier on the L1 term of the generator loss
LAMBDA_GP = 10                # NOTE(review): not referenced by the cells shown here

# --- Schedule and checkpointing ---
NUM_EPOCHS = 2
LOAD_MODEL, SAVE_MODEL = False, False
SAVE_EVERY_EPOCH = 1
CHECKPOINT_DISC, CHECKPOINT_GEN = "disc.pth.tar", "gen.pth.tar"

In [8]:
import torch

def save_some_examples(gen, val_dataset, epoch, folder, cm_input, img_indx=1):
    """Run ``gen`` on one dataset sample and save a 3x2 comparison figure.

    Args:
        gen: generator module, called as ``gen(s2t2, s1t1)`` on batched float32 inputs.
        val_dataset: indexable dataset whose items unpack to the six tensors
            ``(s2t2, s1t2, s2t1, s1t1, cm, rcm)`` without a batch dimension.
        epoch: epoch number; only used in the output filename.
        folder: output directory; created (including missing parents) if needed.
        cm_input: if True, ``cm`` is appended to s2t2 and ``rcm`` to s1t1 along the
            channel axis before calling ``gen`` (mirrors ``train_fn``).
        img_indx: index of the sample in ``val_dataset`` to visualise.

    The figure is written to ``<folder>/img_<img_indx>_epoc<epoch>.png``. The
    generator's previous train/eval mode is restored afterwards, even on error.
    """
    s2t2, s1t2, s2t1, s1t1, cm, rcm = (t.to(DEVICE) for t in val_dataset[img_indx])

    # Keep the plain tensors for plotting; only the generator sees the change maps.
    gen_s2t2, gen_s1t1 = s2t2, s1t1
    if cm_input:
        # Dataset items have no batch dim (they are unsqueeze(0)'d below), so the
        # channel axis is 0 here -- train_fn uses dim=1 because its tensors are batched.
        gen_s2t2 = torch.cat((s2t2, cm), dim=0)
        gen_s1t1 = torch.cat((s1t1, rcm), dim=0)

    # makedirs also handles nested paths and an already-existing folder.
    os.makedirs(folder, exist_ok=True)

    was_training = gen.training
    gen.eval()
    try:
        with torch.no_grad():
            s1t2_generated = gen(gen_s2t2.unsqueeze(0).to(torch.float32),
                                 gen_s1t1.unsqueeze(0).to(torch.float32))
            # Maps [-1, 1] -> [0, 1]; assumes the generator outputs values in [-1, 1]
            # (TODO confirm). NOTE(review): the other panels are plotted un-rescaled.
            s1t2_generated = s1t2_generated * 0.5 + 0.5

            save_s1s2_tensors_plot(
                [s2t1, s1t1, s2t2, s1t2, cm, s1t2_generated[0]],
                ["s2t1", "s1t1", "s2t2", "s1t2", "change map", "Generated s1t2"],
                n_rows=3,
                n_cols=2,
                filename=os.path.join(folder, f"img_{img_indx}_epoc{epoch}.png"),
                fig_size=(8, 10),
            )
    finally:
        gen.train(was_training)


def save_checkpoint(epoc, model, optimizer, filename="my_checkpoint.pth.tar"):
    """Save model and optimizer state as ``epoc<epoc>_<basename>`` beside ``filename``.

    The epoch prefix is applied to the file's base name, so ``"ckpts/gen.pth.tar"``
    becomes ``"ckpts/epoc3_gen.pth.tar"`` instead of ``"epoc3_ckpts/gen.pth.tar"``
    (a different, usually non-existent directory). For a bare filename the result
    is unchanged: ``"gen.pth.tar"`` -> ``"epoc3_gen.pth.tar"``.

    The saved dict contains ``"state_dict"`` (model), ``"optimizer"`` (optimizer)
    and ``"epoch"``.

    Returns:
        The path the checkpoint was written to.
    """
    print("=> Saving checkpoint")
    checkpoint = {
        "state_dict": model.state_dict(),
        "optimizer": optimizer.state_dict(),
        "epoch": epoc,
    }
    directory, basename = os.path.split(filename)
    path = os.path.join(directory, f"epoc{epoc}_{basename}")
    torch.save(checkpoint, path)
    return path


def load_checkpoint(checkpoint_file, model, optimizer, lr):
    """Restore ``model`` and ``optimizer`` from ``checkpoint_file``, then force the lr.

    The checkpoint must provide ``"state_dict"`` and ``"optimizer"`` entries.
    Tensors are mapped onto ``DEVICE`` while loading.
    """
    print("=> Loading checkpoint")
    state = torch.load(checkpoint_file, map_location=DEVICE)
    model.load_state_dict(state["state_dict"])
    optimizer.load_state_dict(state["optimizer"])

    # The restored optimizer state carries the lr it was saved with; overwrite it
    # with the caller's lr so a resumed run uses the currently configured value.
    for group in optimizer.param_groups:
        group.update(lr=lr)

In [9]:
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from tqdm import tqdm


# Let cuDNN benchmark and cache the fastest convolution algorithm per input shape;
# this pays off when input sizes stay the same from one iteration to the next.
torch.backends.cudnn.benchmark = True


def train_fn(disc, gen, loader, opt_disc, opt_gen, l1_loss, bce, g_scaler, d_scaler, weighted_loss = WEIGHTED_LOSS, cm_input = INPUT_CHANGE_MAP):
    """Train ``disc`` and ``gen`` for one pass over ``loader``.

    Conditional-GAN update (adversarial BCE + scaled L1, as in pix2pix): for each
    batch ``(s2t2, s1t2, s2t1, s1t1, cm, rcm)`` the generator predicts s1t2 from
    ``(s2t2, s1t1)``.

    1. Discriminator step: BCE on real s1t2 (target 1) and on the detached fake
       (target 0), averaged.
    2. Generator step: BCE pushing the discriminator towards 1 on the fake, plus
       ``L1_LAMBDA`` times the L1 loss.

    Both steps run under CUDA autocast with their own GradScaler. Progress is
    shown with tqdm and the postfix is refreshed every 10 batches.

    Args:
        disc, gen: discriminator and generator modules.
        loader: iterable of 6-tuples of batched tensors, in the order above.
        opt_disc, opt_gen: optimizers for ``disc`` and ``gen``.
        l1_loss: reconstruction loss; called as ``l1_loss(fake, real, cm, rcm)`` when
            ``weighted_loss`` is True, else as ``l1_loss(fake, real)``.
        bce: adversarial loss applied to the discriminator logits.
        g_scaler, d_scaler: GradScalers for the generator / discriminator steps.
        weighted_loss: selects the ``l1_loss`` call signature (see above).
        cm_input: if True, concatenate ``cm`` onto s2t2 and ``rcm`` onto s1t1 along dim 1.

    Note: the defaults are bound when this cell is executed, so changing
    WEIGHTED_LOSS / INPUT_CHANGE_MAP later requires re-running this cell.
    """
    loop = tqdm(loader, leave=True)

    for idx, batch in enumerate(loop):
        s2t2, s1t2, s2t1, s1t1, cm, rcm = (t.to(DEVICE) for t in batch)
        if cm_input:
            # Tensors are batched here, so dim=1 is the channel axis.
            s2t2 = torch.cat((s2t2, cm), dim=1)
            s1t1 = torch.cat((s1t1, rcm), dim=1)

        # ---- Discriminator step ----
        # torch.autocast(device_type="cuda") replaces the deprecated torch.cuda.amp.autocast().
        with torch.autocast(device_type="cuda"):
            s1t2_fake = gen(s2t2, s1t1)
            D_real = disc(s2t2, s1t1, s1t2)
            D_real_loss = bce(D_real, torch.ones_like(D_real))
            # detach(): the discriminator loss must not backpropagate into the generator.
            D_fake = disc(s2t2, s1t1, s1t2_fake.detach())
            D_fake_loss = bce(D_fake, torch.zeros_like(D_fake))
            D_loss = (D_real_loss + D_fake_loss) / 2

        # Also clears the gradients that the previous generator step left on disc's
        # parameters (the generator loss below backpropagates through disc).
        disc.zero_grad()
        d_scaler.scale(D_loss).backward()
        d_scaler.step(opt_disc)
        d_scaler.update()

        # ---- Generator step ----
        with torch.autocast(device_type="cuda"):
            # Re-score the (non-detached) fake with the just-updated discriminator.
            D_fake = disc(s2t2, s1t1, s1t2_fake)
            G_fake_loss = bce(D_fake, torch.ones_like(D_fake))
            if weighted_loss:
                L1 = l1_loss(s1t2_fake, s1t2, cm, rcm) * L1_LAMBDA
            else:
                L1 = l1_loss(s1t2_fake, s1t2) * L1_LAMBDA
            G_loss = G_fake_loss + L1

        opt_gen.zero_grad()
        g_scaler.scale(G_loss).backward()
        g_scaler.step(opt_gen)
        g_scaler.update()

        if idx % 10 == 0:
            # D_fake here is the generator-step score (after the discriminator update).
            loop.set_postfix(
                D_real=torch.sigmoid(D_real).mean().item(),
                D_fake=torch.sigmoid(D_fake).mean().item(),
                G_loss=G_loss.item(),
                L1=L1.item(),
            )




In [10]:
from temporalgan.temporal_gan_v3_gen import Generator as GeneratorV3
from temporalgan.temporal_gan_v2_gen import Generator as GeneratorV2
from temporalgan.temporal_gan_v1_gen import Generator as GeneratorV1
from temporalgan.temporal_gan_v2_disc import Discriminator as DiscriminatorV2
from temporalgan.temporal_gan_v1_disc import Discriminator as DiscriminatorV1
from temporalgan.lossfunciton.loss_function import WeightedL1Loss

In [11]:
def main():
    """Train the temporal GAN and save an example prediction after every epoch.

    Builds DiscriminatorV1 + GeneratorV2, Adam optimizers (lr=LEARNING_RATE,
    betas=(0.5, 0.999)), a BCE-with-logits adversarial loss and an L1 loss
    (WeightedL1Loss(change_weight=5) when WEIGHTED_LOSS, else nn.L1Loss). It then
    runs ``train_fn`` for NUM_EPOCHS epochs. After each epoch an example figure is
    written to ``evaluation/``. When SAVE_MODEL is set, both networks are also
    checkpointed every SAVE_EVERY_EPOCH epochs.
    """
    # NOTE(review): with INPUT_CHANGE_MAP the cm/rcm maps are concatenated onto the
    # inputs; confirm S2_INCHANNELS / S1_INCHANNELS account for those extra channels.
    disc = DiscriminatorV1(s2_in_channels=S2_INCHANNELS, s1_in_channels=S1_INCHANNELS).to(DEVICE)
    gen = GeneratorV2(
        s2_in_channels=S2_INCHANNELS, s1_in_channels=S1_INCHANNELS, features=64, pam_downsample=2
    ).to(DEVICE)
    opt_disc = optim.Adam(disc.parameters(), lr=LEARNING_RATE, betas=(0.5, 0.999))
    opt_gen = optim.Adam(gen.parameters(), lr=LEARNING_RATE, betas=(0.5, 0.999))
    BCE = nn.BCEWithLogitsLoss()
    L1_LOSS = WeightedL1Loss(change_weight=5) if WEIGHTED_LOSS else nn.L1Loss()

    if LOAD_MODEL:
        # NOTE(review): save_checkpoint writes "epoc<N>_<name>", so CHECKPOINT_GEN /
        # CHECKPOINT_DISC must be pointed at one of those files to resume a run.
        load_checkpoint(CHECKPOINT_GEN, gen, opt_gen, LEARNING_RATE)
        load_checkpoint(CHECKPOINT_DISC, disc, opt_disc, LEARNING_RATE)

    transform = transforms.Compose([S2S1Normalize(), myToTensor()])

    # NOTE(review): despite the name, these paths point at the `test` split of the
    # extra-light data; presumably switch to `train` for a real run.
    train_dataset = Sen12Dataset(s1_t1_dir="E:\\s1s2\\s1s2_patched_light\\s1s2_patched_extra_light\\2021\\s1_imgs\\test",
                                 s2_t1_dir="E:\\s1s2\\s1s2_patched_light\\s1s2_patched_extra_light\\2021\\s2_imgs\\test",
                                 s1_t2_dir="E:\\s1s2\\s1s2_patched_light\\s1s2_patched_extra_light\\2019\\s1_imgs\\test",
                                 s2_t2_dir="E:\\s1s2\\s1s2_patched_light\\s1s2_patched_extra_light\\2019\\s2_imgs\\test",
                                 transform=transform,
                                 two_way=False)

    train_loader = DataLoader(
        train_dataset,
        batch_size=BATCH_SIZE,
        shuffle=True,
        num_workers=NUM_WORKERS,
    )
    g_scaler = torch.cuda.amp.GradScaler()
    d_scaler = torch.cuda.amp.GradScaler()
    # TODO: build a held-out validation dataset and visualise from it instead of
    # train_dataset.

    for epoch in range(1, NUM_EPOCHS + 1):
        print(f"Epoch: {epoch}")
        train_fn(disc, gen, train_loader, opt_disc, opt_gen, L1_LOSS, BCE, g_scaler, d_scaler)

        # The previous `SAVE_MODEL and ... or True` was always true; checkpoints are
        # now gated properly while example figures are still saved every epoch.
        if SAVE_MODEL and epoch % SAVE_EVERY_EPOCH == 0:
            save_checkpoint(epoch, gen, opt_gen, filename=CHECKPOINT_GEN)
            save_checkpoint(epoch, disc, opt_disc, filename=CHECKPOINT_DISC)
        save_some_examples(gen, train_dataset, epoch, folder="evaluation", cm_input=INPUT_CHANGE_MAP, img_indx=1)


In [12]:
import matplotlib
# Select the non-interactive Agg backend: example figures are written to files
# (save_some_examples passes filename=...) rather than shown inline.
matplotlib.use(backend="Agg")
main()

Epoch: 1


100%|██████████| 4/4 [00:08<00:00,  2.09s/it, D_fake=0.654, D_real=0.452, G_loss=454, L1=454]


Epoch: 2


100%|██████████| 4/4 [00:03<00:00,  1.02it/s, D_fake=0.421, D_real=0.508, G_loss=468, L1=467]
