In [1]:
# Automatically re-import edited modules (e.g. `experiments_utils`) before running code.
%load_ext autoreload
%autoreload 2

In [2]:
import os

# Cluster-specific locations of the trained run and the fashion dataset. Each can be
# overridden through an environment variable so the notebook runs elsewhere without
# editing this cell; the defaults are the original hard-coded paths.
model_dir = os.environ.get(
    'DICOMOGAN_MODEL_DIR',
    '/kuacc/users/mali18/dicomogan/logs/_9132_seq_modulation_gaugan22022-10-24T00-36-21')
# Original video frames: <img_root>/<video>/<frame>.png
img_root = os.environ.get(
    'FASHION_IMG_ROOT',
    '/kuacc/users/abond19/datasets/aligned_fashion_dataset')
# Inverted (reconstructed) frames, same layout as img_root.
inverted_img_root = os.environ.get(
    'FASHION_INVERTED_IMG_ROOT',
    '/kuacc/users/abond19/datasets/inverted_fashion_dataset')
# Per-frame latent codes: <inversion_root>/<video>/<frame>.pt
inversion_root = os.environ.get(
    'FASHION_INVERSION_ROOT',
    '/kuacc/users/abond19/datasets/w+_fashion_dataset/fashion/PTI/')

In [3]:
device = "cuda"  # target device for the model and every tensor loaded below

In [4]:
# `experiments_utils` supplies `load_model_from_dir`. NOTE(review): later cells also
# use `to_PIL`, `forward_triple` and `np` without importing them, presumably via this
# star import — TODO confirm and import them explicitly.
from experiments_utils import *

model = load_model_from_dir(model_dir)
model = model.to(device)

Using /scratch/users/mali18/.cache/torch_extensions/py37_cu113 as PyTorch extensions root...
Detected CUDA files, patching ldflags
Emitting ninja build file /scratch/users/mali18/.cache/torch_extensions/py37_cu113/fused/build.ninja...
Building extension module fused...
Allowing ninja to set a default number of workers... (overridable by setting the environment variable MAX_JOBS=N)
ninja: no work to do.
Loading extension module fused...
Loading custom kernel...
Using /scratch/users/mali18/.cache/torch_extensions/py37_cu113 as PyTorch extensions root...
Detected CUDA files, patching ldflags
Emitting ninja build file /scratch/users/mali18/.cache/torch_extensions/py37_cu113/upfirdn2d/build.ninja...
Building extension module upfirdn2d...
Allowing ninja to set a default number of workers... (overridable by setting the environment variable MAX_JOBS=N)
ninja: no work to do.
Loading extension module upfirdn2d...
loaded pretrained LPIPS loss from taming/modules/autoencoder/lpips/vgg.pth
Restored

In [5]:
# load images
import os
import torch 
from PIL import Image
import torchvision.transforms as transforms
# File suffixes recognised by the loaders below (matching is case-sensitive).
IMG_EXTENSIONS = ['.png', '.PNG']
TXT_EXTENSIONS = ['.txt']

# Optional preprocessing: set `crop` / `size` to a 2-element sequence to enable a
# center crop / resize. With both left as None, frames are only converted to tensors.
crop = None
size = None

trans_list = [transforms.CenterCrop(tuple(crop))] if crop is not None else []
if size is not None:
    trans_list += [transforms.Resize(tuple(size))]
trans_list += [transforms.ToTensor()]
img_transform = transforms.Compose(trans_list)

def is_image_file(filename):
    """Return True when `filename` ends with one of IMG_EXTENSIONS (case-sensitive)."""
    return filename.endswith(tuple(IMG_EXTENSIONS))

def is_text_file(filename):
    """Return True when `filename` ends with one of TXT_EXTENSIONS (case-sensitive)."""
    return filename.endswith(tuple(TXT_EXTENSIONS))

def get_image(img_path):
    """Open the image at `img_path` and return an RGB copy of it.

    The file is opened in a `with` block, so its handle is released even when
    decoding or conversion raises. `convert` returns a new, fully loaded image,
    so the result stays usable after the source file is closed.
    """
    with Image.open(img_path) as img:
        return img.convert('RGB')

def get_inversion(inversion_path, expected_shape=(1, 18, 512)):
    """Load a saved latent code from `inversion_path` and validate its shape.

    Args:
        inversion_path: path to a `.pt` file holding a single tensor.
        expected_shape: required shape of the loaded tensor; the default matches
            one (1, 18, 512) W+ code as produced by the inversion step.

    Returns:
        The loaded tensor, mapped to CPU.

    Raises:
        ValueError: if the tensor's shape differs from `expected_shape`.
    """
    # NOTE: torch.load unpickles the file; only use it on trusted, locally produced data.
    w_vector = torch.load(inversion_path, map_location='cpu')
    # Explicit check instead of `assert`, which is stripped when Python runs with -O.
    if tuple(w_vector.shape) != tuple(expected_shape):
        raise ValueError(
            f"Inverted vector has incorrect shape: expected {tuple(expected_shape)}, "
            f"got {tuple(w_vector.shape)} (from {inversion_path})")
    return w_vector

def load_video(vid_path, start=5, stop=15):
    """Load a window of frames for one video, together with their inversions.

    The image files in `<img_root>/<vid_path>` are sorted by name and the
    `[start:stop]` slice of them is used (the defaults reproduce the original
    hard-coded window). For every selected frame file `<file>` this also loads
    the inverted image `<inverted_img_root>/<vid_path>/<file>` and the latent
    `<inversion_root>/<vid_path>/<stem>.pt`.

    Args:
        vid_path: video directory name, relative to the dataset roots.
        start, stop: slice bounds into the sorted list of frame image files.

    Returns:
        (images, inversions, sampleT, inversion_imgs), all moved to `device`:
        images / inversion_imgs are the `img_transform`-ed frames stacked along a
        new leading frame dimension; inversions are the per-frame latents
        concatenated along dim 0; sampleT is a float tensor of the frame numbers
        parsed from the file names.

    Raises:
        FileNotFoundError: if the window selects no image files.
    """
    video_dir = os.path.join(img_root, vid_path)
    # Filter to image files *before* slicing, so stray non-image files in the
    # directory cannot shift or shrink the selected window.
    frame_files = sorted(f for f in os.listdir(video_dir) if is_image_file(f))[start:stop]
    if not frame_files:
        raise FileNotFoundError(
            f"No image frames in {video_dir!r} for window [{start}:{stop}]")

    images, inversions, sampleT, inversion_imgs = [], [], [], []
    for fname in frame_files:
        stem = os.path.splitext(fname)[0]
        images.append(img_transform(get_image(os.path.join(video_dir, fname))))
        inversion_imgs.append(img_transform(get_image(os.path.join(inverted_img_root, vid_path, fname))))
        inversions.append(get_inversion(os.path.join(inversion_root, vid_path, stem + ".pt")))
        # Frame files are named by their (zero-padded) frame number, e.g. "00010".
        sampleT.append(int(stem))

    return (torch.stack(images).to(device),
            torch.cat(inversions, 0).to(device),
            torch.Tensor(sampleT).to(device),
            torch.stack(inversion_imgs).to(device))

In [6]:
import imageio
import torchvision
def save_gif(video, range, save_path):
    """Write `video` to `save_path` as an animated GIF, one grid image per step.

    Args:
        video: sequence of time steps (e.g. a tensor iterated along dim 0). Each
            element is either a single frame (C x H x W) or a batch of frames
            (B x C x H x W); a batch is laid out side by side in one grid row.
        range: (min, max) value range of the input, used to rescale it to [0, 1],
            e.g. (0, 1) for loaded frames and (-1, 1) for model outputs. The name
            is kept for backward compatibility with existing callers.
        save_path: output file path.
    """
    import numpy as np  # explicit, instead of relying on a star import elsewhere

    value_range = range  # avoid using the builtin-shadowing name below
    with imageio.get_writer(save_path, mode='I') as writer:
        for b_frames in video:
            # `value_range` replaces make_grid's deprecated `range` keyword.
            grid = torchvision.utils.make_grid(b_frames,
                                               nrow=b_frames.shape[0],
                                               normalize=True,
                                               value_range=value_range)
            frame = grid.detach().cpu().numpy()
            # CHW float in [0, 1] -> HWC uint8 for imageio.
            frame = (np.transpose(frame, (1, 2, 0)) * 255).astype(np.uint8)
            writer.append_data(frame)


In [7]:
# Sample video IDs (directory names under the dataset roots) used for the qualitative results.
videos = ["8c110571", "2c111960", "8c110147"]

# Reconstruct and edit sample videos

In [8]:
from tqdm import tqdm
import warnings
# NOTE(review): this silences *every* warning for the rest of the kernel session,
# including deprecation warnings from torch/torchvision; consider narrowing it.
warnings.filterwarnings("ignore")
import torch.nn as nn

# Target text used to drive the edit. (It previously contained the same sentence
# twice, concatenated without a separator — a copy-paste slip.)
tgt_desc = ("A picture of a woman wearing a Women's Multi Color Regular fit Body "
            "with a Overlap V-neckline and Long sleeves")
# Only needed for the (disabled) direction-style conditioning
# `clip_encode_text([tgt_desc]) - clip_encode_text([src_desc])`.
src_desc = "a photo of a woman wearing short sleeves t-shirt"


def edit_from_latent(images, sampleT, w_init, txt_feat, gif_path, png_path):
    """Run the model on one video starting from latent `w_init` and save the results.

    Saves element 0 of the model output (presumably the edited frames of the single
    video in the batch) as a GIF at `gif_path`, and the StyleGAN rendering of
    `w_init` alone as a PNG at `png_path`.
    """
    edited_video = model(images.unsqueeze(0), sampleT, w_init, txt_feat)[0]
    save_gif(edited_video, (-1, 1), gif_path)
    to_PIL(model.stylegan_G(w_init)[0]).save(png_path)


with torch.no_grad():
    # The prompt is the same for every video, so encode it once.
    txt_feat = model.clip_encode_text([tgt_desc])
    run_name = os.path.basename(os.path.normpath(model_dir))
    for video in tqdm(videos):
        images, inversions, sampleT, inversion_imgs = load_video(video)
        save_dir = os.path.join('applications_results', run_name, video)
        os.makedirs(save_dir, exist_ok=True)

        # Input frames and their inversions (both in [0, 1]).
        save_gif(images, (0, 1), f'{save_dir}/original.gif')
        save_gif(inversion_imgs, (0, 1), f'{save_dir}/inversion.gif')

        # Compare how the choice of initial latent affects the edit.
        mid = len(inversions) // 2
        init_latents = [
            ('Mean_Frame', 'mean_frame', inversions.mean(0, keepdim=True)),
            ('Last_Frame', 'last_frame', inversions[-1:]),
            ('First_Frame', 'first_frame', inversions[0:1]),
            ('Middle_Frame', 'middle_frame', inversions[mid:mid + 1]),
        ]
        for gif_name, png_name, w_init in init_latents:
            edit_from_latent(images, sampleT, w_init, txt_feat,
                             f'{save_dir}/{gif_name}.gif',
                             f'{save_dir}/{png_name}.png')

        
        

  0%|          | 0/3 [00:00<?, ?it/s]

Setting up PyTorch plugin "bias_act_plugin"... Done.
Setting up PyTorch plugin "upfirdn2d_plugin"... Done.


100%|██████████| 3/3 [00:35<00:00, 11.67s/it]


# Generate edited frames for evaluation

In [60]:
# One video name per line. Strip whitespace and skip blank lines: `split('\n')` turned
# a trailing newline into an empty name, which `load_video` would resolve to the
# dataset root itself.
with open('data/fashion/fashion_train_videos.txt', 'r') as f:
    train_videos = [line.strip() for line in f if line.strip()]

In [61]:
# One video name per line. Strip whitespace and skip blank lines: `split('\n')` turned
# a trailing newline into an empty name, which `load_video` would resolve to the
# dataset root itself.
with open('data/fashion/fashion_test_videos.txt', 'r') as f:
    test_videos = [line.strip() for line in f if line.strip()]

In [133]:
from tqdm import tqdm
def generate(lst, save_dir, desc, bs=4):
    """Edit every video in `lst` with the text `desc` and save the frames as PNGs.

    Videos are processed in batches of up to `bs`. The last batch is smaller when
    `len(lst)` is not a multiple of `bs`. Frames of each edited video are written
    to `<save_dir>/<video_name>/<frame_idx:06d>.png`.

    Args:
        lst: sequence of video names (a list or a 1-D numpy array).
        save_dir: root output directory; per-video subdirectories are created.
        desc: text description applied to every video.
        bs: maximum number of videos per forward pass.
    """
    with torch.no_grad():
        for start in tqdm(range(0, len(lst), bs)):
            # Slicing (rather than indexing start..start+bs) keeps the final,
            # possibly shorter, batch in bounds.
            batch_names = lst[start:start + bs]
            images, inversions = [], []
            for name in batch_names:
                frames, w_codes, sampleT, _ = load_video(name)
                images.append(frames)
                inversions.append(w_codes)
            # NOTE(review): `sampleT` comes from the last video of the batch, so every
            # video in a batch is assumed to share the same frame numbers; stacking
            # also requires equal frame counts — confirm for the dataset.
            edited_videos = forward_triple(model,
                                           torch.stack(images, 0),
                                           torch.stack(inversions, 0),
                                           sampleT,
                                           [desc] * len(batch_names))
            for video, video_name in zip(edited_videos, batch_names):
                save_path = os.path.join(save_dir, video_name)
                os.makedirs(save_path, exist_ok=True)
                for frame_idx, frame in enumerate(video):
                    to_PIL(frame).save(os.path.join(save_path, f"{frame_idx:06d}.png"))

In [134]:
# model_outputs/<run name>/train. normpath drops any trailing slash so basename still
# yields the run directory name (split('/')[-1] would give '').
train_save_dir = os.path.join("model_outputs", os.path.basename(os.path.normpath(model_dir)), "train")

In [139]:
# model_outputs/<run name>/test. normpath drops any trailing slash so basename still
# yields the run directory name (split('/')[-1] would give '').
test_save_dir = os.path.join("model_outputs", os.path.basename(os.path.normpath(model_dir)), "test")

In [136]:
# Text description passed to `generate` for every evaluated video.
desc = "A picture of a woman wearing a Women's Black Regular fit Blouse with a Round neckline and Long sleeves and a blue jean"

In [137]:
# NOTE(review): rebinds `train_videos` (read as a list earlier) to a numpy array, so its
# type depends on which cells were run last; nothing visible here uses it afterwards
# (generation runs on `test_videos`). `np` presumably comes from the
# `experiments_utils` star import — confirm.
train_videos = np.array(train_videos)

In [140]:
# Edit the first 200 test videos with `desc`; frames go under `test_save_dir`.
generate(test_videos[:200], save_dir=test_save_dir, desc=desc)

100%|██████████| 50/50 [20:35<00:00, 24.72s/it]


In [96]:
# Inspect the native size of a dataset frame (PIL reports (width, height)). The `with`
# block closes the file handle that a bare `Image.open(...).size` leaves open.
with Image.open(os.path.join(img_root, '7c90376', '00010.png')) as sample_frame:
    sample_frame_size = sample_frame.size
sample_frame_size

(512, 1024)