In [4]:
import os
import torch
import numpy as np

from augmentations.augmentations import aug_combined, aug_rotation

In [53]:

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

torch.autograd.set_grad_enabled(False)

from torchvision import transforms
from utils.resnet_custom import resnet50_baseline

def eval_transforms(pretrained=False):
    if pretrained:
        mean = (0.485, 0.456, 0.406)
        std = (0.229, 0.224, 0.225)

    else:
        mean = (0.5, 0.5, 0.5)
        std = (0.5, 0.5, 0.5)

    trnsfrms_val = transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize(mean=mean, std=std)]
    )

    return trnsfrms_val

resnet = resnet50_baseline(pretrained=True)
resnet.to(device)
resnet.eval()

roi_transforms = eval_transforms(pretrained=True)

import augmentations.augmentations as A

In [54]:
# import generator
from generator import GeneratorMLP

# load model
dagan_run_code = "gan_mlp_s1_lr1e-03_None_b64_20221909_182052"
dagan_state_path = f"/home/guillaume/Documents/uda/project-augmented-embeddings/2-dagan/results/sicapv2/{dagan_run_code}/s_4_checkpoint.pt"
dagan_state_dict = torch.load(dagan_state_path)

n_tokens = 1024
dropout = 0.2
generator = GeneratorMLP(n_tokens, dropout)
generator.load_state_dict(dagan_state_dict["G_state_dict"])
generator.eval().to(device)
# print(generator)

GeneratorMLP(
  (encoder): Sequential(
    (0): Linear(in_features=2048, out_features=1024, bias=True)
    (1): ELU(alpha=1.0)
    (2): AlphaDropout(p=0.2, inplace=False)
    (3): Linear(in_features=1024, out_features=512, bias=True)
    (4): ELU(alpha=1.0)
    (5): AlphaDropout(p=0.2, inplace=False)
    (6): Linear(in_features=512, out_features=256, bias=True)
    (7): ELU(alpha=1.0)
    (8): AlphaDropout(p=0.2, inplace=False)
  )
  (decoder): Sequential(
    (0): Linear(in_features=256, out_features=512, bias=True)
    (1): ELU(alpha=1.0)
    (2): AlphaDropout(p=0.2, inplace=False)
    (3): Linear(in_features=512, out_features=512, bias=True)
    (4): ELU(alpha=1.0)
    (5): AlphaDropout(p=0.2, inplace=False)
    (6): Linear(in_features=512, out_features=1024, bias=True)
  )
)

In [55]:
def get_true_emb(img):
    # print(img)
    # print(img.shape)
    
    aug_img = A.aug_rotation(img)
    # print(aug_img)

    aug_img = roi_transforms(aug_img).unsqueeze(0)
    aug_img = aug_img.to(device)
    true_emb = resnet(aug_img)
    return true_emb

def get_gen_emb(emb):
    with torch.no_grad():
        emb = emb.to(device)
        noise = torch.randn(emb.size(0), emb.size(1), requires_grad=False).to(device)
        aug_embs = generator.forward(emb, noise)
    
    return aug_embs

In [57]:
import time

img = np.random.randint(low=0, high=255, size=(256,256,3), dtype='uint8')
emb = torch.randn(1,1024)

# ------ TIME -----

def time_func(f):
    st = time.time()
    f()
    et = time.time()
    res = et - st
    return res

n = 10000
time_true = 0
time_gen = 0
for _ in range(n):
    # true
    st = time.time()
    get_true_emb(img)
    et = time.time()
    res = et - st
    time_true += res

    # gen
    st = time.time()
    get_gen_emb(emb)
    et = time.time()
    res = et - st
    time_gen += res
    
time_true /= n
time_gen /= n

print("true:        ", time_true)
print("gen (mlp):   ", time_gen)


true:         0.006424951219558716
gen (mlp):    0.0005146847009658813
