In [1]:
# Reload edited modules (e.g. experiments_utils) automatically before each cell runs.
%load_ext autoreload
%autoreload 2

In [181]:
# Checkpoint directory passed to load_model_from_dir; its last path component also names the results/ subfolder.
model_dir = '/kuacc/users/mali18/dicomogan/logs/_from_mean_w_real_imgs_rec2022-10-04T11-42-09'
# One sub-directory per video containing the .png frames read by load_video.
img_root = '/kuacc/users/abond19/datasets/aligned_fashion_dataset'
# Inversion reconstructions, one sub-directory per video, same frame file names as img_root.
inverted_img_root =  '/kuacc/users/abond19/datasets/inverted_fashion_dataset'
# Per-frame inversion latents saved as <frame>.pt (get_inversion expects 1 x 18 x 512 tensors).
inversion_root =  '/kuacc/users/abond19/datasets/w+_fashion_dataset/fashion/PTI/'

In [182]:
# Device for the model and for every tensor returned by load_video.
device = 'cuda'

In [183]:
# Target text descriptions. Each video is edited once per entry; entry j is saved as desc_<j>.gif.
tgt_desc = [
    'Blue T-shirts Made from rayon-nylon blend Round neckline Slogan printed front Pearl detail Short sleeves Regular fit',
    'Dresses Shift White One shoulder design Made from cotton Lace detailed ruffled yoke Regular fit Mini length Sleeveless',
    'Multi Leggings Made from poly-lycra blend Elasticated waistband All over crane print Tape detail at sides Bodycon fit'
]

In [184]:
# Video folder names to edit; each is looked up under img_root, inverted_img_root and inversion_root.
videos = [
    '8c110571',
    '2c111960',
    '8c110147'
]

In [185]:
# Explicit import instead of `import *`: only load_model_from_dir is needed from this module,
# and every other name used below is imported explicitly in later cells.
from experiments_utils import load_model_from_dir

model = load_model_from_dir(model_dir)

Restored from /kuacc/users/mali18/dicomogan/logs/_from_mean_w_real_imgs_rec2022-10-04T11-42-09/VideoManipulation/_from_mean_w_real_imgs_rec2022-10-04T11-42-09/checkpoints/epoch=201-step=75547.ckpt


In [186]:
# Put the model on the same device that load_video moves its tensors to.
# NOTE(review): confirm load_model_from_dir returns the model in eval() mode for inference.
model = model.to(device=device)

In [187]:
def save_gif(video, save_path):
    """Write a sequence of frames to ``{save_path}.gif``.

    Args:
        video: iterable of H x W x C frames. The callers in this notebook pass a
            T x H x W x C numpy array produced with ``.permute(0, 2, 3, 1)``.
            Float frames are treated as [0, 1] images: they are clipped to that
            range and converted to uint8 here. Without this, imageio rescales any
            float frame that leaves [0, 1] by its own min/max, so brightness
            flickers between frames. Integer frames are written unchanged.
        save_path: output path without the ``.gif`` extension.
    """
    import imageio
    import numpy as np

    with imageio.get_writer(f'{save_path}.gif', mode='I') as writer:
        for frame in video:
            frame = np.asarray(frame)
            if np.issubdtype(frame.dtype, np.floating):
                frame = np.round(np.clip(frame, 0.0, 1.0) * 255).astype(np.uint8)
            writer.append_data(frame)

In [188]:
# load images
import os
import torch 
from PIL import Image
import torchvision.transforms as transforms
IMG_EXTENSIONS = ['.png', '.PNG']
TXT_EXTENSIONS = ['.txt']

# Frame preprocessing. crop / size are (height, width) pairs as torchvision expects;
# set either one to None to skip that step. ToTensor is always the last step.
crop = None
size = (256, 192)
trans_list = [
    *([transforms.CenterCrop(tuple(crop))] if crop is not None else []),
    *([transforms.Resize(tuple(size))] if size is not None else []),
    transforms.ToTensor(),
]
img_transform = transforms.Compose(trans_list)

def is_image_file(filename):
    """Return True when ``filename`` ends with one of IMG_EXTENSIONS (case-sensitive)."""
    return filename.endswith(tuple(IMG_EXTENSIONS))

def is_text_file(filename):
    """Return True when ``filename`` ends with one of TXT_EXTENSIONS (case-sensitive)."""
    return filename.endswith(tuple(TXT_EXTENSIONS))

def get_image(img_path):
    """Load the image at ``img_path`` and return an RGB copy of it.

    The file is opened in a ``with`` block, so the handle is released as soon as the
    pixels have been converted. Otherwise release depends on lazy loading and
    garbage collection of the source image.
    """
    with Image.open(img_path) as src:
        return src.convert('RGB')

def get_inversion(inversion_path):
    """Load a saved inversion tensor onto the CPU and verify its shape.

    Raises:
        AssertionError: if the loaded tensor is not 1 x 18 x 512.
    """
    latent = torch.load(inversion_path, map_location='cpu')
    expected_shape = (1, 18, 512)
    assert latent.shape == expected_shape, "Inverted vector has incorrect shape"
    return latent

def load_video(vid_path, start=5, end=25):
    """Load frames ``start:end`` of one video together with their inversions.

    Args:
        vid_path: video folder name, resolved under ``img_root``, ``inverted_img_root``
            and ``inversion_root``.
        start, end: slice bounds over the video's image frames after sorting. The
            defaults reproduce the previous hard-coded ``[5:25]`` window.

    Returns:
        Tuple ``(images, inversions, sampleT, inversion_imgs)``, all moved to ``device``:
        ``images`` / ``inversion_imgs`` stack the ``img_transform`` outputs along dim 0,
        ``inversions`` concatenates the per-frame latents along dim 0, and ``sampleT``
        holds the integer frame numbers parsed from the file names, as a float tensor.

    Raises:
        FileNotFoundError: if the requested window contains no image frames.
    """
    def frame_key(name):
        # Numeric names sort by value ('2.png' before '10.png'); other names sort as text after them.
        stem = os.path.splitext(name)[0]
        return (0, int(stem), name) if stem.isdecimal() else (1, 0, name)

    frame_dir = os.path.join(img_root, vid_path)
    # Filter before slicing so that non-image files in the folder do not shrink the window.
    frame_files = sorted((f for f in os.listdir(frame_dir) if is_image_file(f)), key=frame_key)[start:end]
    if not frame_files:
        raise FileNotFoundError(f"No image frames in {frame_dir} for slice [{start}:{end}]")

    images, inversions, sampleT, inversion_imgs = [], [], [], []
    for f in frame_files:
        imname = os.path.splitext(f)[0]
        images.append(img_transform(get_image(os.path.join(img_root, vid_path, f))))
        inversion_imgs.append(img_transform(get_image(os.path.join(inverted_img_root, vid_path, f))))
        inversions.append(get_inversion(os.path.join(inversion_root, vid_path, imname + ".pt")))
        sampleT.append(int(imname))

    return (torch.stack(images).to(device),
            torch.cat(inversions, 0).to(device),
            torch.Tensor(sampleT).to(device),
            torch.stack(inversion_imgs).to(device))

In [189]:
from torch import nn
def forward_w_color(self, videos, inversions, sampleT, input_desc):
    """Edit every frame's own inversion with a text- and frame-conditioned delta, then render it.

    Args:
        self: the model. This function reads ``clip_encode_text``, ``bVAE_enc``, ``mapping``,
            ``vae_cond_dim``, ``style_mapper``, ``delta_inversion_weight`` and ``stylegan_G``.
        videos: B x T x C x H x W frames (the notebook passes values in [0, 1]).
        inversions: B x T x n_layers x D per-frame latents.
        sampleT: frame indices for the T frames. They are scaled by 0.01 and shifted so
            the first entry is 0 before being passed to ``bVAE_enc``.
        input_desc: list of B target descriptions.

    Returns:
        Tensor of shape B x T x C' x H' x W' equal to ``stylegan_G(w) / 2 + 0.5``. If the
        generator output is in [-1, 1] this is roughly [0, 1]; it is NOT clamped.

    Raises:
        ValueError: if ``inversions`` does not share the (B, T) leading dims of ``videos``.
    """
    bs, T, ch, height, width = videos.size()
    if inversions.shape[:2] != (bs, T):
        raise ValueError(
            f"inversions must have leading dims (B, T) = {(bs, T)} to match videos, "
            f"got {tuple(inversions.shape)}")
    _, _, n_channels, dim = inversions.shape

    # Relative time stamps for the VAE encoder.
    ts = sampleT * 0.01
    ts = ts - ts[0]

    # Flatten to time-major order (row t * B + b); every T*B tensor below uses this order.
    frames_tb = videos.permute(1, 0, 2, 3, 4).contiguous().view(T * bs, ch, height, width)
    inversions_tb = inversions.permute(1, 0, 2, 3).contiguous().reshape(T * bs, n_channels, dim)

    # Half-resolution frames for the VAE encoder, returned to B x T x C x H/2 x W/2 order.
    vid_rs = nn.functional.interpolate(
        frames_tb, scale_factor=0.5, mode="bicubic", align_corners=False, recompute_scale_factor=True)
    vid_rs = vid_rs.view(T, bs, ch, int(height * 0.5), int(width * 0.5)).permute(1, 0, 2, 3, 4)

    # One text embedding per description (B x D), repeated for every frame -> T*B x D.
    txt_feat = self.clip_encode_text(input_desc)
    txt_feat = txt_feat.unsqueeze(0).repeat(T, 1, 1).view(bs * T, -1)

    # Encode the frames; only the latent dims after vae_cond_dim feed the mapping network.
    zs, zd, _, _ = self.bVAE_enc(vid_rs, ts)
    z_vid = torch.cat((zs, zd), 1)
    latentw = self.mapping(z_vid[:, self.vae_cond_dim:])
    frame_rep = torch.cat((latentw, txt_feat), -1)

    # Predict a latent delta for each frame's own inversion and render the result.
    w_latents = inversions_tb + self.delta_inversion_weight * self.style_mapper(inversions_tb, frame_rep)
    imgs = self.stylegan_G(w_latents) / 2 + 0.5
    return imgs.reshape(T, bs, *imgs.shape[1:]).permute(1, 0, 2, 3, 4)



In [190]:
from torch import nn
def forward_from_mean(self, videos, inversions, sampleT, input_desc):
    """Render every frame from the video's mean inversion plus a per-frame predicted delta.

    The source latent is the average of each video's inversions over time. The delta is
    conditioned on that averaged latent and on the frame's video latent + text embedding.

    Args:
        self: the model. This function reads ``clip_encode_text``, ``bVAE_enc``, ``mapping``,
            ``vae_cond_dim``, ``style_mapper``, ``delta_inversion_weight`` and ``stylegan_G``.
        videos: B x T x C x H x W frames (the notebook passes values in [0, 1]).
        inversions: B x T x n_layers x D per-frame latents.
        sampleT: frame indices for the T frames. They are scaled by 0.01 and shifted so
            the first entry is 0 before being passed to ``bVAE_enc``.
        input_desc: list of B target descriptions.

    Returns:
        Tensor of shape B x T x C' x H' x W' equal to ``stylegan_G(w) / 2 + 0.5``. If the
        generator output is in [-1, 1] this is roughly [0, 1]; it is NOT clamped.

    Raises:
        ValueError: if ``inversions`` does not share the (B, T) leading dims of ``videos``.
    """
    bs, T, ch, height, width = videos.size()
    if inversions.shape[:2] != (bs, T):
        raise ValueError(
            f"inversions must have leading dims (B, T) = {(bs, T)} to match videos, "
            f"got {tuple(inversions.shape)}")
    _, _, n_channels, dim = inversions.shape

    # Relative time stamps for the VAE encoder.
    ts = sampleT * 0.01
    ts = ts - ts[0]

    # Time-major flattening (row t * B + b); every T*B tensor below uses this order.
    frames_tb = videos.permute(1, 0, 2, 3, 4).contiguous().view(T * bs, ch, height, width)
    inversions_tf = inversions.permute(1, 0, 2, 3)  # T x B x n_layers x D

    # Half-resolution frames for the VAE encoder, returned to B x T x C x H/2 x W/2 order.
    vid_rs = nn.functional.interpolate(
        frames_tb, scale_factor=0.5, mode="bicubic", align_corners=False, recompute_scale_factor=True)
    vid_rs = vid_rs.view(T, bs, ch, int(height * 0.5), int(width * 0.5)).permute(1, 0, 2, 3, 4)

    # One text embedding per description (B x D), repeated for every frame -> T*B x D.
    txt_feat = self.clip_encode_text(input_desc)
    txt_feat = txt_feat.unsqueeze(0).repeat(T, 1, 1).view(bs * T, -1)

    # Encode the frames; only the latent dims after vae_cond_dim feed the mapping network.
    zs, zd, _, _ = self.bVAE_enc(vid_rs, ts)
    z_vid = torch.cat((zs, zd), 1)
    latentw = self.mapping(z_vid[:, self.vae_cond_dim:])
    frame_rep = torch.cat((latentw, txt_feat), -1)

    # Average each video's inversions over time and broadcast that latent to all T frames.
    src_inversion = inversions_tf.mean(0, keepdim=True)  # 1 x B x n_layers x D
    src_inversion = src_inversion.repeat(T, 1, 1, 1).reshape(T * bs, n_channels, dim)

    # Predict a per-frame delta from the shared source latent and render.
    w_latents = src_inversion + self.delta_inversion_weight * self.style_mapper(src_inversion, frame_rep)
    imgs = self.stylegan_G(w_latents) / 2 + 0.5
    return imgs.reshape(T, bs, *imgs.shape[1:]).permute(1, 0, 2, 3, 4)



In [191]:
from torch import nn
def tmp_forward(self, videos, inversions, sampleT, input_desc):
    """Mean-inversion editing where ``style_mapper`` takes video latent and text separately.

    Identical to ``forward_from_mean`` except that ``style_mapper`` is called as
    ``style_mapper(src_inversion, latentw, txt_feat)`` instead of receiving the two
    features concatenated.

    Args:
        self: the model. This function reads ``clip_encode_text``, ``bVAE_enc``, ``mapping``,
            ``vae_cond_dim``, ``style_mapper``, ``delta_inversion_weight`` and ``stylegan_G``.
        videos: B x T x C x H x W frames (the notebook passes values in [0, 1]).
        inversions: B x T x n_layers x D per-frame latents.
        sampleT: frame indices for the T frames. They are scaled by 0.01 and shifted so
            the first entry is 0 before being passed to ``bVAE_enc``.
        input_desc: list of B target descriptions.

    Returns:
        Tensor of shape B x T x C' x H' x W' equal to ``stylegan_G(w) / 2 + 0.5``. If the
        generator output is in [-1, 1] this is roughly [0, 1]; it is NOT clamped.

    Raises:
        ValueError: if ``inversions`` does not share the (B, T) leading dims of ``videos``.
    """
    bs, T, ch, height, width = videos.size()
    if inversions.shape[:2] != (bs, T):
        raise ValueError(
            f"inversions must have leading dims (B, T) = {(bs, T)} to match videos, "
            f"got {tuple(inversions.shape)}")
    _, _, n_channels, dim = inversions.shape

    # Relative time stamps for the VAE encoder.
    ts = sampleT * 0.01
    ts = ts - ts[0]

    # Time-major flattening (row t * B + b); every T*B tensor below uses this order.
    frames_tb = videos.permute(1, 0, 2, 3, 4).contiguous().view(T * bs, ch, height, width)
    inversions_tf = inversions.permute(1, 0, 2, 3)  # T x B x n_layers x D

    # Half-resolution frames for the VAE encoder, returned to B x T x C x H/2 x W/2 order.
    vid_rs = nn.functional.interpolate(
        frames_tb, scale_factor=0.5, mode="bicubic", align_corners=False, recompute_scale_factor=True)
    vid_rs = vid_rs.view(T, bs, ch, int(height * 0.5), int(width * 0.5)).permute(1, 0, 2, 3, 4)

    # One text embedding per description (B x D), repeated for every frame -> T*B x D.
    txt_feat = self.clip_encode_text(input_desc)
    txt_feat = txt_feat.unsqueeze(0).repeat(T, 1, 1).view(bs * T, -1)

    # Encode the frames; only the latent dims after vae_cond_dim feed the mapping network.
    zs, zd, _, _ = self.bVAE_enc(vid_rs, ts)
    z_vid = torch.cat((zs, zd), 1)
    latentw = self.mapping(z_vid[:, self.vae_cond_dim:])

    # Average each video's inversions over time and broadcast that latent to all T frames.
    src_inversion = inversions_tf.mean(0, keepdim=True)  # 1 x B x n_layers x D
    src_inversion = src_inversion.repeat(T, 1, 1, 1).reshape(T * bs, n_channels, dim)

    # Predict a per-frame delta (video latent and text passed as separate arguments) and render.
    delta = self.style_mapper(src_inversion, latentw, txt_feat)
    w_latents = src_inversion + self.delta_inversion_weight * delta
    imgs = self.stylegan_G(w_latents) / 2 + 0.5
    return imgs.reshape(T, bs, *imgs.shape[1:]).permute(1, 0, 2, 3, 4)


In [192]:
def forward_hairclip(self, videos, inversions, sampleT, input_desc):
    """Edit each frame's inversion with a text-conditioned delta (no video encoder), then render.

    Args:
        self: the model. This function reads ``get_text_embedding``, ``mapping_network``,
            ``delta_inversion_weight`` and ``G``.
        videos: B x T x ... frames; only the (B, T) sizes are used.
        inversions: B x T x n_layers x D per-frame latents.
        sampleT: unused. Kept so all forward_* variants share one signature.
        input_desc: list of B target descriptions.

    Returns:
        Tensor of shape B x T x C' x H' x W' equal to ``G(w) / 2 + 0.5``. If the generator
        output is in [-1, 1] this is roughly [0, 1]; it is NOT clamped.

    Raises:
        ValueError: if ``inversions`` does not share the (B, T) leading dims of ``videos``.
    """
    bs, T = videos.shape[:2]
    if inversions.shape[:2] != (bs, T):
        raise ValueError(
            f"inversions must have leading dims (B, T) = {(bs, T)} to match videos, "
            f"got {tuple(inversions.shape)}")
    _, _, n_channels, dim = inversions.shape

    # Time-major flattening (row t * B + b), matching the text features below.
    inversions_tb = inversions.permute(1, 0, 2, 3).contiguous().reshape(T * bs, n_channels, dim)

    # One text embedding per description (B x D), repeated for every frame -> T*B x D.
    txt_feat = self.get_text_embedding(input_desc)
    txt_feat = txt_feat.unsqueeze(0).repeat(T, 1, 1).view(bs * T, -1)

    adjusted_latent = inversions_tb + self.delta_inversion_weight * self.mapping_network(inversions_tb, txt_feat)
    imgs = self.G(adjusted_latent) / 2 + 0.5
    return imgs.reshape(T, bs, *imgs.shape[1:]).permute(1, 0, 2, 3, 4)


In [193]:
def forward_hairclip_same_dir(self, videos, inversions, sampleT, input_desc):
    """Apply the first frame's text-conditioned delta to every frame of each video, then render.

    Deltas are predicted for all frames, but only each video's frame-0 delta is kept
    and added to all of that video's frame inversions.

    Args:
        self: the model. This function reads ``get_text_embedding``, ``mapping_network``,
            ``delta_inversion_weight`` and ``G``.
        videos: B x T x ... frames; only the (B, T) sizes are used.
        inversions: B x T x n_layers x D per-frame latents.
        sampleT: unused. Kept so all forward_* variants share one signature.
        input_desc: list of B target descriptions.

    Returns:
        Tensor of shape B x T x C' x H' x W' equal to ``G(w) / 2 + 0.5``. If the generator
        output is in [-1, 1] this is roughly [0, 1]; it is NOT clamped.

    Raises:
        ValueError: if ``inversions`` does not share the (B, T) leading dims of ``videos``.
    """
    bs, T = videos.shape[:2]
    if inversions.shape[:2] != (bs, T):
        raise ValueError(
            f"inversions must have leading dims (B, T) = {(bs, T)} to match videos, "
            f"got {tuple(inversions.shape)}")
    _, _, n_channels, dim = inversions.shape

    # Time-major flattening (row t * B + b), matching the text features below.
    inversions_tb = inversions.permute(1, 0, 2, 3).contiguous().reshape(T * bs, n_channels, dim)

    # One text embedding per description (B x D), repeated for every frame -> T*B x D.
    txt_feat = self.get_text_embedding(input_desc)
    txt_feat = txt_feat.unsqueeze(0).repeat(T, 1, 1).view(bs * T, -1)

    deltas = self.delta_inversion_weight * self.mapping_network(inversions_tb, txt_feat)
    # Keep only the t = 0 slice (one delta per video) and reuse it for every frame.
    first_deltas = deltas.reshape(T, bs, n_channels, dim)[0:1]
    deltas = first_deltas.repeat(T, 1, 1, 1).reshape(T * bs, n_channels, dim)

    imgs = self.G(inversions_tb + deltas) / 2 + 0.5
    return imgs.reshape(T, bs, *imgs.shape[1:]).permute(1, 0, 2, 3, 4)

In [194]:
from tqdm import tqdm
# Render, for every video: the original frames, their inversion reconstructions, and one
# edited GIF per target description, under results/<model run name>/<video>/.
with torch.no_grad():
    for video in tqdm(videos):
        images, inversions, sampleT, inversion_imgs = load_video(video)

        # basename(normpath(...)) still yields the run name when model_dir ends with '/'.
        run_name = os.path.basename(os.path.normpath(model_dir))
        save_dir = os.path.join('results', run_name, video)
        os.makedirs(save_dir, exist_ok=True)

        # Reference GIFs (frames are T x C x H x W; save_gif takes T x H x W x C).
        save_gif(images.permute(0, 2, 3, 1).cpu().numpy(), f'{save_dir}/original')
        save_gif(inversion_imgs.permute(0, 2, 3, 1).cpu().numpy(), f'{save_dir}/inversion')

        # Batch the same video once per target description: B = len(tgt_desc).
        n_desc = len(tgt_desc)
        images_batch = images.unsqueeze(0).repeat(n_desc, 1, 1, 1, 1)
        inversions_batch = inversions.unsqueeze(0).repeat(n_desc, 1, 1, 1)
        edited_videos = forward_from_mean(model, images_batch, inversions_batch, sampleT, tgt_desc)

        for j, ed_video in enumerate(edited_videos):
            # The generator output is not clamped; clip to [0, 1] so every frame is encoded consistently.
            save_gif(ed_video.clamp(0, 1).permute(0, 2, 3, 1).cpu().numpy(), f'{save_dir}/desc_{j}')
        
        











100%|██████████| 3/3 [00:35<00:00, 11.92s/it]
