In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms ,models
import cv2
import os 
import gc
import warnings
from models.models import ENet,UNet, TempDiscriminator3D
from data.datagen_stage3 import DataGenerator
import warnings
warnings.filterwarnings("ignore")
from IPython import display
import numpy as np

# ---- run configuration ----
EPOCHS = 10            # epochs to train in stage 3 (loop runs starting_epoch .. EPOCHS-1)
starting_epoch = 0     # replaced by the checkpoint's epoch + 1 when stage-3 weights are resumed
device = 'cuda'
batch_size = 1
H, W = 128, 128        # frame height / width used by the data generator and generated clips
dataset_len = 8726     # batches per epoch (assumed); used to build the global TensorBoard step

In [2]:
#models
# MC3-18 video backbone, used only as a fixed feature extractor for the contrastive motion loss.
# Its last two children are dropped (same as the earlier `[:-1][:-1]` slicing).
encoder = models.video.mc3_18(weights='MC3_18_Weights.KINETICS400_V1')
encoder = nn.Sequential(*list(encoder.children())[:-2]).to(device).eval() #131072 output vector for 1,3,8,256,256
# No optimizer owns these parameters. Freezing them stops every backward() from filling
# (and never clearing) their .grad buffers; gradients still flow to the generator's outputs.
encoder.requires_grad_(False)
total_encoder_parameters = sum(p.numel() for p in encoder.parameters())
print(f'Total encoder Parameters {total_encoder_parameters} (frozen)')

generator = ENet(in_channels=15, out_channels=3, residual_blocks=64).train().to(device)
g_optimizer = torch.optim.Adam(generator.parameters(), lr=1e-5, betas=(0.9, 0.9))


discriminator = TempDiscriminator3D(d=32)
discriminator.to(device).train()
d_optimizer = torch.optim.Adam(discriminator.parameters(), lr=1e-5, betas=(0.9, 0.9))

# MobileNetV3-small feature trunk for the per-frame perceptual / contextual losses.
# Named weight enum instead of the deprecated boolean `weights=True`, which maps to these
# same default ImageNet weights.
mobile = models.mobilenet_v3_small(weights='MobileNet_V3_Small_Weights.IMAGENET1K_V1')
mobile = mobile.features.eval().to(device)
mobile.requires_grad_(False)  # frozen for the same reason as `encoder`
total_mobile_parameters = sum(p.numel() for p in mobile.parameters())
print(f'Total mobile Parameters {total_mobile_parameters} (frozen)')


Total encoder Parameters 11490240
Total Trainable Parameters: 4855335
Total TempDiscriminator3D Trainable Parameters: 8369633
Total encoder Parameters 927008


In [3]:
#load weights
generator_ckpt_dir = './stage3/generator/'
discriminator_ckpt_dir = './stage3/discriminator/'


def latest_checkpoint(ckpt_dir):
    """Return the path of the newest ``<name>_<epoch>.pth`` file in `ckpt_dir`, or None.

    Files are ordered by the integer epoch after the first underscore, so
    ``generator_10.pth`` ranks above ``generator_9.pth`` (a plain string sort would not).
    A missing directory is treated the same as an empty one.
    """
    if not os.path.isdir(ckpt_dir):
        return None
    ckpts = [x for x in os.listdir(ckpt_dir) if x.endswith('.pth')]
    if not ckpts:
        return None
    latest = max(ckpts, key=lambda x: int(x.split('.')[0].split('_')[1]))
    return os.path.join(ckpt_dir, latest)


gen_ckpt_path = latest_checkpoint(generator_ckpt_dir)
if gen_ckpt_path is not None:
    # Resume stage 3: generator + its optimizer, then the discriminator if one was saved.
    state_dict = torch.load(gen_ckpt_path, map_location=device)
    generator.load_state_dict(state_dict['model'])
    starting_epoch = state_dict['epoch'] + 1
    g_optimizer.load_state_dict(state_dict['optimizer'])
    print('loaded generator weights from previous session')

    disc_ckpt_path = latest_checkpoint(discriminator_ckpt_dir)
    if disc_ckpt_path is not None:
        state_dict = torch.load(disc_ckpt_path, map_location=device)
        discriminator.load_state_dict(state_dict['model'])
        starting_epoch = state_dict['epoch'] + 1
        d_optimizer.load_state_dict(state_dict['optimizer'])
        print('loaded weights from previous session')
    else:
        print('no discriminator checkpoint found; discriminator starts from scratch')
    print(f'starting from epoch {starting_epoch}')
else:
    # Fresh stage 3: initialise the generator from the newest stage-2 checkpoint.
    ckpt_dir = './stage2/generator/'
    stage2_path = latest_checkpoint(ckpt_dir)
    if stage2_path is None:
        raise FileNotFoundError(f'no stage-2 generator checkpoint (*.pth) found in {ckpt_dir}')
    state_dict = torch.load(stage2_path, map_location=device)
    generator.load_state_dict(state_dict['model'])
    latest = os.path.basename(stage2_path)
    print(f'loaded weights from second training stage {latest}')

loaded weights from second training stage generator_3.pth


In [4]:
#loss functions
def contextual_loss(x, y, h=0.5):
    """Computes contextual loss between x and y.

    Equation numbers refer to the contextual-loss paper this implementation follows.

    Args:
      x: features of shape (N, C, H, W).
      y: features of shape (N, C, H, W).
      h: bandwidth of Eq (3), w = exp((1 - d_tilde) / h); smaller values make matching sharper.
    Returns:
      cx_loss = contextual loss between x and y (Eq (1) in the paper), as a 0-d tensor.
    """
    assert x.size() == y.size()
    N, C, H, W = x.size()   # e.g., 10 x 512 x 14 x 14. In this case, the number of points is 196 (14x14).
    # Center both feature sets on the per-channel mean of the target y.
    y_mu = y.mean(dim=(0, 2, 3), keepdim=True)                                   # (1, C, 1, 1)
    # L2-normalize each feature vector along channels. F.normalize clamps the norm at a tiny eps,
    # so a vector equal to y_mu becomes 0 instead of 0 / 0 = NaN.
    x_normalized = F.normalize(x - y_mu, p=2, dim=1)
    y_normalized = F.normalize(y - y_mu, p=2, dim=1)
    # The equation at the bottom of page 6 in the paper:
    # cosine similarity for every pair (x_i, y_j), computed as one batched matmul.
    x_normalized = x_normalized.reshape(N, C, -1)                                # (N, C, H*W)
    y_normalized = y_normalized.reshape(N, C, -1)                                # (N, C, H*W)
    cosine_sim = torch.bmm(x_normalized.transpose(1, 2), y_normalized)           # (N, H*W, H*W)
    d = 1 - cosine_sim                                  # (N, H*W, H*W)  d[n, i, j] means d_ij for n-th data
    d_min, _ = torch.min(d, dim=2, keepdim=True)        # (N, H*W, 1)
    # Eq (2): distances relative to the closest y_j of each x_i
    d_tilde = d / (d_min + 1e-5)
    # Eq (3)
    w = torch.exp((1 - d_tilde) / h)
    # Eq (4): normalize affinities over j for each x_i
    cx_ij = w / torch.sum(w, dim=2, keepdim=True)       # (N, H*W, H*W)
    # Eq (1): for each y_j keep its best-matching x_i, then average over j
    cx = torch.mean(torch.max(cx_ij, dim=1)[0], dim=1)  # (N, )
    cx_loss = torch.mean(-torch.log(cx + 1e-5))
    return cx_loss
def perceptual_loss(x, y):
    """Mean squared distance between channel-wise L2-normalized feature maps.

    Args:
        x: feature tensor with channels on dim 1 (e.g. (N, C, H, W)).
        y: feature tensor with the same shape as ``x``.
    Returns:
        0-d tensor ``sum((x_hat - y_hat) ** 2) / x.numel()``, where ``x_hat`` / ``y_hat`` are
        ``x`` / ``y`` divided by their L2 norm along dim 1.
    """
    # F.normalize clamps the norm at a tiny eps, so all-zero feature vectors give 0, not NaN.
    x_norm = F.normalize(x, p=2, dim=1)
    y_norm = F.normalize(y, p=2, dim=1)
    # Sum the squares directly: sqrt(s) ** 2 has the same value but an infinite sqrt gradient
    # at s == 0, which turns identical inputs into NaN gradients.
    return torch.sum((x_norm - y_norm) ** 2) / x.numel()
def contrastive_motion_loss(stable, unstable, generated, margin=1.0):
    """Triplet hinge loss on video-encoder embeddings.

    The stable clip is the anchor A, the generated clip the positive P and the unstable clip
    the negative N. Each clip is passed through ``preprocess`` and ``encoder``, flattened per
    sample and L2-normalized. Returns ``max(0, mean ||A - P|| - mean ||A - N|| + margin)``,
    with the means taken over the batch.

    Args:
        stable, unstable, generated: clips in the 5D [B, C, T, H, W] layout ``preprocess`` expects.
        margin: hinge margin (default 1.0, the previously hard-coded value).
    Returns:
        0-d tensor, always >= 0.
    """
    def embed(clip):
        # Flatten per sample (works for any batch size), then unit-normalize each embedding.
        return F.normalize(encoder(preprocess(clip)).flatten(1), p=2, dim=1)

    A = embed(stable)
    P = embed(generated)
    N = embed(unstable)
    # torch.norm's backward is defined (0) at zero distance, unlike sqrt(sum(...)).
    d1 = torch.norm(A - P, p=2, dim=1).mean()  # mean anchor-positive distance
    d2 = torch.norm(A - N, p=2, dim=1).mean()  # mean anchor-negative distance
    # Element-wise hinge max(0, .); clamp is required because torch.max(t, 0) reduces over dim 0.
    return torch.clamp(d1 - d2 + margin, min=0)

# BCE on probabilities (inputs must already be in [0, 1]), averaged over all elements.
binary_cross_entropy = nn.BCELoss(reduction='mean')

def preprocess(tensor, resize_shape=(128, 171), crop_shape=(112, 112),
               mean=(0.155, 0.161, 0.153), std=(0.228, 0.231, 0.226)):
    """
    Apply the same 2D transforms to every frame of a 5D video tensor.

    Each frame is bilinearly resized to `resize_shape`, center-cropped to `crop_shape`
    and normalized with `mean` / `std`.

    Args:
        tensor (torch.Tensor): 5D tensor of shape [B, C, T, H, W].
        resize_shape (tuple): Target (H, W) for resizing each frame.
        crop_shape (tuple): Target (H, W) for center cropping each frame.
        mean (sequence): Per-channel mean values for normalization.
        std (sequence): Per-channel standard deviation values for normalization.

    Returns:
        torch.Tensor: Contiguous 5D tensor of shape [B, C, T, crop_H, crop_W].
    """
    transforms_3d = transforms.Compose([
        transforms.Resize(resize_shape, interpolation=transforms.InterpolationMode.BILINEAR),
        transforms.CenterCrop(crop_shape),
        transforms.Normalize(mean=mean, std=std),
    ])
    B, C, T, H, W = tensor.shape
    # Fold time into the batch ([B, C, T, H, W] -> [B*T, C, H, W]) so every frame is
    # transformed by one batched call instead of a Python loop plus torch.cat.
    frames = tensor.permute(0, 2, 1, 3, 4).reshape(B * T, C, H, W)
    frames = transforms_3d(frames)
    # Unfold back to [B, C, T, h, w]; contiguous to match the layout torch.cat used to produce.
    out = frames.reshape(B, T, C, frames.shape[-2], frames.shape[-1]).permute(0, 2, 1, 3, 4)
    return out.contiguous()

In [9]:
# Stage-3 training data; the first argument is presumably the (H, W, channels) frame shape.
data_gen = DataGenerator(
    (H, W, 3),
    txt_path='./trainlist_stage3.txt',
    skip=2,
)
# NOTE: iter(...) is evaluated once here. An iterator cannot be rewound, so anything that
# loops over `train_ds` again only sees whatever is left of it.
train_ds = iter(data_gen())

In [10]:
from torch.utils.tensorboard import SummaryWriter

# Write this stage's TensorBoard event files under ./runs/stage3/ instead of the default location.
writer = SummaryWriter(log_dir='./runs/stage3/')

In [11]:
# Stage-3 adversarial training.
# Per batch: generate FRAMES_PER_CLIP frames, update the temporal discriminator on
# (fake, real) clips, then update the generator with perceptual + contextual +
# adversarial + 10 * contrastive-motion losses.
FRAMES_PER_CLIP = 8
LOG_EVERY = 1000  # batches between TensorBoard scalars / checkpoints
DISPLAY_MEANS = np.array([0.155, 0.161, 0.153], dtype=np.float32)
# NOTE(review): first std is 0.22 here but 0.228 in `preprocess` — confirm which one the data uses.
DISPLAY_STDS = np.array([0.22, 0.231, 0.226], dtype=np.float32)


def to_display_image(clip):
    """First frame of the first clip in a [B, C, T, H, W] tensor as an (H, W, C) uint8 array."""
    frame = clip[0, :, 0].detach().permute(1, 2, 0).cpu().numpy()
    frame = frame * DISPLAY_STDS + DISPLAY_MEANS  # new array; never writes into tensor memory
    return np.clip(frame * 255.0, 0, 255).astype(np.uint8)


cv2.namedWindow('window', cv2.WINDOW_NORMAL)
stop_requested = False
try:
    for epoch in range(starting_epoch, EPOCHS):
        # Fresh iterator every epoch: one created once is exhausted after the first pass,
        # which silently turned every later epoch into a no-op.
        # NOTE(review): assumes data_gen() can be called again to start a new pass.
        train_ds = iter(data_gen())
        # Reset per epoch so a partial window from the previous epoch does not skew the averages.
        g_running_loss = 0.0
        d_running_loss = 0.0
        for idx, batch in enumerate(train_ds):
            torch.cuda.empty_cache()
            input_sequence, unstable_sequence, stable_sequence = batch
            unstable_sequence = unstable_sequence.to(device)
            stable_sequence = stable_sequence.to(device)

            # Per-frame generation + image losses
            generated_frames = []
            g_loss = 0
            for k in range(FRAMES_PER_CLIP):
                x = input_sequence[k].to(device)
                y = stable_sequence[:, :, k, :, :]
                y_hat = generator(x)
                generated_frames.append(y_hat)
                feat1 = mobile(y_hat)
                with torch.no_grad():  # target features never need a graph
                    feat2 = mobile(y)
                percept_loss = perceptual_loss(feat1, feat2)
                context_loss = contextual_loss(feat1, feat2)
                g_loss += percept_loss + context_loss
            # T x [B, 3, H, W] -> [B, 3, T, H, W]; stays on the device and in the autograd graph.
            generated_sequence = torch.stack(generated_frames, dim=2)

            # Update temporal discriminator (generator output detached)
            d_optimizer.zero_grad()
            fake_prediction = discriminator(generated_sequence.detach())
            fake_labels = torch.zeros_like(fake_prediction)
            real_prediction = discriminator(stable_sequence)
            real_labels = torch.ones_like(real_prediction)
            predictions = torch.cat([fake_prediction, real_prediction], dim=0)
            labels = torch.cat([fake_labels, real_labels], dim=0)
            d_loss = binary_cross_entropy(predictions, labels)
            d_loss.backward()
            d_optimizer.step()

            # Update generator (discriminator grads produced here are cleared by the next zero_grad)
            score = discriminator(generated_sequence)
            generator_adv_loss = binary_cross_entropy(score, torch.ones_like(score))
            cml = contrastive_motion_loss(stable_sequence,
                                          unstable_sequence,
                                          generated_sequence)
            g_loss = g_loss + generator_adv_loss + 10 * cml
            g_optimizer.zero_grad()
            g_loss.backward()
            g_optimizer.step()

            g_running_loss += g_loss.item()
            d_running_loss += d_loss.item()
            window = idx % LOG_EVERY + 1  # batches accumulated since the last reset
            print(f'\repoch: {epoch},batch: {idx}, generator_loss:{g_running_loss / window} ,'
                  f'per: {percept_loss.item()}, cx: {context_loss.item()}, '
                  f'cml: {cml.item()},adv: {generator_adv_loss.item()} , '
                  f'discriminator_loss:{d_loss.item():.3f}', end='')
            del feat1, feat2, percept_loss, context_loss
            gc.collect()

            # visualization: generated vs. stable first frame, side by side
            concat = cv2.hconcat([to_display_image(generated_sequence),
                                  to_display_image(stable_sequence)])
            cv2.imshow('window', concat)
            if cv2.waitKey(1) & 0xFF == ord('9'):
                stop_requested = True  # '9' ends training entirely, not just this epoch
                break

            if idx % LOG_EVERY == LOG_EVERY - 1:
                global_step = epoch * dataset_len + idx
                writer.add_scalar('generator_loss', g_running_loss / LOG_EVERY, global_step)
                writer.add_scalar('discriminator_loss', d_running_loss / LOG_EVERY, global_step)
                g_running_loss = 0.0
                d_running_loss = 0.0
                # Save where the load-weights cell looks for stage-3 checkpoints so a restarted
                # session actually resumes from them.
                os.makedirs(generator_ckpt_dir, exist_ok=True)
                os.makedirs(discriminator_ckpt_dir, exist_ok=True)
                torch.save({'model': generator.state_dict(),
                            'optimizer': g_optimizer.state_dict(),
                            'epoch': epoch},
                           os.path.join(generator_ckpt_dir, f'generator_{epoch}.pth'))
                torch.save({'model': discriminator.state_dict(),
                            'optimizer': d_optimizer.state_dict(),
                            'epoch': epoch},
                           os.path.join(discriminator_ckpt_dir, f'discriminator_{epoch}.pth'))
        if stop_requested:
            break
finally:
    # Close the preview window even if training raises or is interrupted.
    cv2.destroyAllWindows()


epoch: 0,batch: 0, generator_loss:16.995271682739258 ,per: 0.0008255462162196636, cx: 0.7793654203414917,                  cml: 1.034876823425293,adv: 0.4183591604232788 , discriminator_loss:0.771