In [None]:
import torch.hub
import os
import sys
# Make the directory above this notebook importable so `match_utils` resolves.
# realpath of the notebook file name resolves relative to the kernel's working directory.
current = os.path.dirname(os.path.realpath("inversion-stylegan2.ipynb"))
parent = os.path.dirname(current)
# Guarded so re-running the cell does not append duplicate entries.
if parent not in sys.path:
    sys.path.append(parent)


In [None]:
import torch
# Device for all models and tensors. The notebook targets GPU 3; fall back to the CPU on
# machines without that many GPUs instead of failing on the first .to(device).
device = torch.device('cuda:3' if torch.cuda.is_available() and torch.cuda.device_count() > 3 else 'cpu')
# DINO is loaded further down with models.load_discr, which assigns `dino` before any use,
# so no torch.hub download is needed here.

In [None]:
from transformers import CLIPProcessor, CLIPModel
import torch
import torchvision
from torchvision.models import resnet50
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
import clip
from PIL import Image
import requests
import torch.hub
import time
import pickle
import math

from match_utils import matching, stats, proggan, nethook, dataset, loading, plotting, layers, models

In [None]:
from pytorch_pretrained_biggan import (BigGAN, one_hot_from_names, truncated_noise_sample,
                                       save_as_images, display_in_terminal)


# Load the StyleGAN2 (LSUN horse) generator and the DINO discriminator-side network,
# each together with the list of layer names used for activation matching.
gan, gan_layers = models.load_gan(
    'stylegan2-lsun_horse',
    path='models/',
    device=device,
)
dino, dino_layers = models.load_discr(
    'dino',
    path='models/',
    device=device,
)

In [None]:
# Resolve per-network layer information used later by layers.find_act to map flat
# unit indices to (layer, unit) pairs.
ganlayers, dinolayers = layers.get_layers(
    gan, gan_layers,
    dino, dino_layers,
    "stylegan2-lsun_horse", "dino",
    device,
)

In [None]:
# Precomputed GAN-unit x DINO-unit matching table plus per-layer activation statistics.
# The default location is author-specific; set GAN_MATCHES_STATS_DIR to override it.
STATS_DIR = os.environ.get(
    "GAN_MATCHES_STATS_DIR",
    "/home/yossi_gandelsman/gan_matches/results/results_dino_resnet_stylegan2-lsun_horse",
)
table, gan_stats, dino_stats = loading.load_stats(STATS_DIR, device)

### Best Buddies (mutually best-matching GAN/DINO units)

In [None]:
# Best matching score of each GAN unit (row-wise maximum of the table).
match_scores = table.max(dim=1).values

In [None]:
# For each GAN unit (row): index of its best DINO unit; for each DINO unit (column): its best GAN unit.
gan_matches = table.argmax(dim=1)
dino_matches = table.argmax(dim=0)

In [None]:
# "Best buddies": GAN unit i and DINO unit j form a perfect match when j is i's best
# DINO match AND i is j's best GAN match (a mutual argmax in `table`).
gan_unit_ids = torch.arange(table.shape[0], device=gan_matches.device)
is_mutual = dino_matches[gan_matches] == gan_unit_ids

perfect_matches = gan_unit_ids[is_mutual].tolist()        # GAN unit indices (ascending)
dino_perfect_matches = gan_matches[is_mutual].tolist()    # their DINO partners
perfect_match_scores = [match_scores[i] for i in perfect_matches]
num_perfect_matches = len(perfect_matches)

print(num_perfect_matches)
print(num_perfect_matches/table.shape[0])

In [None]:
# Wrap both networks so the activations of the chosen layers are retained on each forward pass.
# Guarded so re-running this cell does not wrap an already-instrumented model
# (a double wrap would break later `gan.model.synthesis` / `gan.model.num_ws` accesses).
if not isinstance(gan, nethook.InstrumentedModel):
    gan = nethook.InstrumentedModel(gan)
    # detach=False keeps the autograd graph: the activation losses below backpropagate
    # through these retained activations into the latent being optimized.
    gan.retain_layers(gan_layers, detach = False)

if not isinstance(dino, nethook.InstrumentedModel):
    dino = nethook.InstrumentedModel(dino)
    dino.retain_layers(dino_layers)

In [None]:
# Convert each flat GAN unit index into the per-layer form returned by layers.find_act
# (later indexed as [layer, unit]). Already-converted entries are left untouched, so
# re-running this cell does not corrupt the list.
perfect_matches = [
    layers.find_act(unit, ganlayers) if isinstance(unit, int) else unit
    for unit in perfect_matches
]

In [None]:
# Same conversion for the DINO partners; idempotent for the same reason as above.
dino_perfect_matches = [
    layers.find_act(unit, dinolayers) if isinstance(unit, int) else unit
    for unit in dino_perfect_matches
]

In [None]:
from scipy.stats import truncnorm
def truncate_noise(size, truncation, random_state=None):
    '''
    Sample a tensor of truncated-normal noise.

    Draws i.i.d. samples from a standard normal truncated to
    [-truncation, truncation] (via scipy.stats.truncnorm) and returns them as a
    float32 CPU tensor.
    Parameters:
        size: output shape, e.g. (n_samples, z_dim); an int or a tuple of ints
        truncation: truncation bound, a positive scalar (the interval must be non-empty)
        random_state: optional seed (int) or numpy Generator/RandomState forwarded to
            scipy; None (default) uses numpy's global random state
    Returns:
        torch.Tensor of shape `size`, dtype float32
    '''
    truncated_noise = truncnorm.rvs(-truncation, truncation, size=size, random_state=random_state)
    # np.asarray also handles a scalar draw (size=()), which torch.Tensor(...) would reject.
    return torch.from_numpy(np.asarray(truncated_noise, dtype=np.float32))

In [None]:
# Initial latent for the Z-space experiments, drawn from a normal truncated to [-1, 1].
# truncate_noise samples with scipy's truncnorm, which uses numpy's global RNG by default;
# seeding it makes the starting latent reproducible across runs.
np.random.seed(0)
z1 = truncate_noise((1,512), 1).to(device)
from torch.autograd import Variable

In [None]:
# Optimizable copy of the initial latent: a fresh leaf tensor that Adam can update.
# (torch.autograd.Variable is deprecated; requires_grad_ on a tensor is the modern form.)
z = z1.detach().clone().requires_grad_(True)

# Conditioning input for the `gan(z, c)` calls below; it was never defined in this notebook.
# NOTE(review): None assumes an unconditional generator that ignores c — confirm for this checkpoint.
c = None

# NOTE(review): mean_latent is not referenced by later cells; presumably an average latent
# over 4096 samples for truncation.
with torch.no_grad():
    mean_latent = gan.model.mean_latent(4096)


In [None]:
# Zero anchor for the (currently commented-out) latent regularization terms; no grad.
reg = torch.zeros(1, 512, device=device)

In [None]:
def show_gan_im(gan_im, save_path=None):
    '''
    Display the first image of a generator output batch.

    Parameters:
        gan_im: channels-first image batch (indexed [0] then permuted to H, W, C), with
            values assumed in [-1, 1]; mapped to [0, 1] for display
        save_path: optional file path; when given, the displayed image is also written
            with plt.imsave
    '''
    # Clamp so matplotlib does not clip-and-warn on values slightly outside [0, 1].
    im = ((gan_im.detach()[0] + 1) / 2).clamp(0, 1)
    im = im.permute(1, 2, 0).cpu()
    if save_path is not None:
        plt.imsave(save_path, im.numpy())  # imsave takes (fname, array)
    plt.imshow(im)
    plt.show()

In [None]:
# Preview the starting latent with truncation psi = 0.7 (display only, so no autograd graph).
# NOTE(review): the previous call `gan([z], 0.7, c)` matched neither the `gan(z, c)` form used in the
# loops below nor the W-space API; this assumes a StyleGAN2-ADA-style
# forward(z, c, truncation_psi=...) that returns a single image — confirm.
c = None  # unconditional model; `c` was otherwise undefined at this point
with torch.no_grad():
    img = gan(z, c, truncation_psi=0.7)
show_gan_im(img)

In [None]:
# Target image to invert (author-local path). Earlier candidate targets:
#real_im = Image.open("/home/amildravid/bigGAN-DINO_swap/val_im/ILSVRC2012_val_00028617-_1_.jpg")
#real_im = Image.open("/home/amildravid/bigGAN-DINO_swap/golden_retriever/real/ILSVRC2012_val_00001112.jpg")
#real_im = Image.open("/home/amildravid/bigActivation_Matching/val_im/ILSVRC2012_val_00006981.jpg")
#real_im = Image.open("/home/amildravid/bigGAN-DINO_swap/golden_retriever/sketch/sketch_7.jpg")
# convert("RGB") guarantees 3 channels: grayscale/RGBA/CMYK files would otherwise break the
# 3-channel Normalize and the pixel loss below. The context manager closes the file.
with Image.open("/home/amildravid/activations_matching-main/activations_matching-main/misc/dogface_afhq.jpg") as real_im_file:
    real_im = real_im_file.convert("RGB")
real_im

In [None]:
# Tensorize the target. real_im is overwritten with a tensor here, so remember the PIL image
# first; that keeps this cell re-runnable (ToTensor rejects tensor input).
if isinstance(real_im, Image.Image):
    real_pil = real_im
real_im = torchvision.transforms.ToTensor()(real_pil).unsqueeze(0).to(device)  # (1, C, H, W) in [0, 1]
#real_im = torchvision.transforms.RandomResizedCrop(256)(real_im)
# Pixel-loss target at 512x512. Bicubic resampling can overshoot [0, 1], so clamp back to the
# range the generator output is mapped to in the pixel loss.
real_im = torch.nn.functional.interpolate(real_im, size = (512,512), mode = "bicubic").clamp(0, 1)
# DINO input: 256x256 with ImageNet mean/std normalization.
dino_real_im = torch.nn.functional.interpolate(real_im, size = (256,256), mode = "bicubic")
dino_real_im = torchvision.transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))(dino_real_im)

In [None]:
# Show the preprocessed target (channels-last for matplotlib).
plt.imshow(real_im[0].permute(1, 2, 0).cpu())

In [None]:
# Run DINO once on the target image and collect its retained activations.
# These are fixed targets for the optimization, so no autograd graph is needed.
# NOTE(review): assumes models.load_discr returns DINO in eval mode (BatchNorm) — confirm.
with torch.no_grad():
    dino(dino_real_im)
dino_activs = matching.store_activs(dino, dino_layers)
# Normalize each layer with its precomputed stats ([0] / [1], presumably mean / std).
eps = 0.00001
dino_activs = [(activ - dino_stats[k][0]) / (dino_stats[k][1] + eps) for k, activ in enumerate(dino_activs)]

In [None]:
# DINO activation map of each best-buddy unit, kept 4-D as (1, batch, H, W) for interpolation.
dino_perfect_activs = [
    dino_activs[layer_unit[0]][:, layer_unit[1], :, :].unsqueeze(0)
    for layer_unit in dino_perfect_matches
]

### Pixel Loss

In [None]:
# Restart from the same initial latent z1 so this experiment is independent of any earlier
# optimization and the cell can be re-run.
z = z1.detach().clone().requires_grad_(True)
optim = torch.optim.Adam([z], lr=0.01, betas=(0.5, 0.999))

In [None]:
# Z-space inversion with a pixel-space MSE against the target image.
num_epochs = 1000
show_every = 50  # displaying every step floods the notebook output
for epoch in range(num_epochs):
    optim.zero_grad()
    sample = gan(z, c)
    im = (sample + 1) / 2  # generator output mapped to [0, 1], as in show_gan_im

    loss = torch.mean((im - real_im) ** 2)

    print("E:", epoch+1, "loss:", loss.item())
    loss.backward()
    optim.step()
    if epoch % show_every == 0 or epoch == num_epochs - 1:
        show_gan_im(sample)

    # To dump frames, uncomment (author-local output directory):
    # plt.imsave(f"/home/amildravid/bigGAN-DINO_swap/morph/ex1/im/{epoch:03d}.png",
    #            ((sample[0].permute(1, 2, 0).detach().cpu().numpy() + 1) / 2).clip(0, 1))

### Activation Loss

In [None]:
# Restart from the same initial latent z1; otherwise this experiment would continue from
# the z optimized by the previous (pixel-loss) cell.
z = z1.detach().clone().requires_grad_(True)
optim = torch.optim.Adam([z], lr=0.01, betas=(0.5, 0.999))

In [None]:
# Z-space inversion with the activation loss only: maximize the normalized correlation
# between each best-buddy GAN unit and its DINO unit on the target image.
num_epochs = 1000
show_every = 50  # displaying every step floods the notebook output
eps = 0.00001
for epoch in range(num_epochs):
    optim.zero_grad()
    sample = gan(z, c)

    # Normalize every retained GAN layer with its precomputed stats ([0] / [1],
    # presumably mean / std), as done for the DINO targets.
    gan_activs = matching.store_activs(gan, gan_layers)
    gan_activs = [(activ - gan_stats[k][0]) / (gan_stats[k][1] + eps) for k, activ in enumerate(gan_activs)]

    total_corr = 0
    for idx, dino_map in zip(perfect_matches, dino_perfect_activs):
        gan_map = gan_activs[idx[0]][:, idx[1], :, :].unsqueeze(0)  # 4-D, same layout as dino_map
        # Resize both maps to the larger spatial size. Use dim -2 (height): dim 1 of these
        # 4-D maps is the batch axis, not a spatial size.
        map_size = max(gan_map.shape[-2], dino_map.shape[-2])
        gan_up = torch.nn.functional.interpolate(gan_map, size=(map_size, map_size), mode='bilinear')
        dino_up = torch.nn.functional.interpolate(dino_map, size=(map_size, map_size), mode='bilinear')
        prod = torch.einsum('aixy,ajxy->ij', gan_up, dino_up)
        total_corr = total_corr + prod / torch.sqrt(torch.sum(gan_up**2) * torch.sum(dino_up**2))
    loss = -total_corr  # minimize the negative correlation

    #regularization = 50*torch.mean((z-reg)**2)
    #loss +=  regularization
    print("E:", epoch+1, "loss:", loss.item())
    loss.backward()
    optim.step()
    if epoch % show_every == 0 or epoch == num_epochs - 1:
        show_gan_im(sample)

    # To dump frames, uncomment (author-local output directory):
    # plt.imsave(f"/home/amildravid/bigGAN-DINO_swap/morph/ex1/im/{epoch:03d}.png",
    #            ((sample[0].permute(1, 2, 0).detach().cpu().numpy() + 1) / 2).clip(0, 1))
    
    
    
    


      

### Both Losses

In [None]:
# Restart from the same initial latent z1 so this experiment does not inherit the z
# optimized by the previous cells.
z = z1.detach().clone().requires_grad_(True)
optim = torch.optim.Adam([z], lr=0.01, betas=(0.5, 0.999))

In [None]:
# Z-space inversion with both losses: negative activation correlation plus a weighted
# pixel MSE against the target image.
num_epochs = 1000
show_every = 50  # displaying every step floods the notebook output
pixel_weight = 100  # weight of the pixel MSE relative to the correlation term
eps = 0.00001
for epoch in range(num_epochs):
    optim.zero_grad()
    sample = gan(z, c)

    im = (sample + 1) / 2  # generator output mapped to [0, 1], as in show_gan_im
    pixel_loss = torch.mean((im - real_im) ** 2)

    # Normalize every retained GAN layer with its precomputed stats ([0] / [1],
    # presumably mean / std), as done for the DINO targets.
    gan_activs = matching.store_activs(gan, gan_layers)
    gan_activs = [(activ - gan_stats[k][0]) / (gan_stats[k][1] + eps) for k, activ in enumerate(gan_activs)]

    total_corr = 0
    for idx, dino_map in zip(perfect_matches, dino_perfect_activs):
        gan_map = gan_activs[idx[0]][:, idx[1], :, :].unsqueeze(0)  # 4-D, same layout as dino_map
        # Resize both maps to the larger spatial size. Use dim -2 (height): dim 1 of these
        # 4-D maps is the batch axis, not a spatial size.
        map_size = max(gan_map.shape[-2], dino_map.shape[-2])
        gan_up = torch.nn.functional.interpolate(gan_map, size=(map_size, map_size), mode='bilinear')
        dino_up = torch.nn.functional.interpolate(dino_map, size=(map_size, map_size), mode='bilinear')
        prod = torch.einsum('aixy,ajxy->ij', gan_up, dino_up)
        total_corr = total_corr + prod / torch.sqrt(torch.sum(gan_up**2) * torch.sum(dino_up**2))
    activ_loss = -total_corr

    loss = activ_loss + pixel_weight * pixel_loss
    #regularization = 100*torch.mean((z-reg)**2)
    #loss +=  regularization
    print("E:", epoch+1, "pixel:", float(pixel_loss), "activ:", float(activ_loss), "loss:", loss.item())
    loss.backward()
    optim.step()
    if epoch % show_every == 0 or epoch == num_epochs - 1:
        show_gan_im(sample)

    # To dump frames, uncomment (author-local output directory):
    # plt.imsave(f"/home/amildravid/bigGAN-DINO_swap/morph/ex1/im/{epoch:03d}.png",
    #            ((sample[0].permute(1, 2, 0).detach().cpu().numpy() + 1) / 2).clip(0, 1))
    
    
    
    


      

# W-Space

In [None]:
# Start the W-space experiments from the mapping network's stored w_avg (presumably the
# running average w), as a batch of one. A random start would be torch.randn((1,512)).to(device).
w1 = gan.model.mapping.w_avg.unsqueeze(0).clone()

In [None]:
# Inspect the starting latent's shape.
w1.size()

In [None]:
# Optimizable copy of the starting w: a fresh leaf tensor each run, so re-running this cell
# restarts the experiment. (torch.autograd.Variable is deprecated.)
w = w1.detach().clone().requires_grad_(True)
optim = torch.optim.Adam([w], lr=0.01, betas=(0.5, 0.999))

In [None]:
# Presumably the number of w vectors gan.model.synthesis expects. The W-space cells below
# repeat a single w along dim 1, and that repeat count must match this value.
gan.model.num_ws

In [None]:
# Render the starting w, broadcast to every synthesis layer. Use the model's num_ws rather
# than a hard-coded count; display only, so no autograd graph.
with torch.no_grad():
    init_im = (gan.model.synthesis(w.unsqueeze(1).repeat(1, gan.model.num_ws, 1)) + 1) / 2
plt.imshow(torch.permute(init_im[0].detach().cpu(), (1,2,0)).clamp(0, 1))

### Pixel Loss

In [None]:
# W-space inversion with a pixel-space MSE: one w is shared by all synthesis layers.
num_epochs = 1000
show_every = 50  # displaying every step floods the notebook output
num_ws = gan.model.num_ws  # number of per-layer w copies the synthesis network takes
for epoch in range(num_epochs):
    optim.zero_grad()
    sample = gan.model.synthesis(w.unsqueeze(1).repeat(1, num_ws, 1))
    im = (sample + 1) / 2  # generator output mapped to [0, 1], as in show_gan_im

    loss = torch.mean((im - real_im) ** 2)

    print("E:", epoch+1, "loss:", loss.item())
    loss.backward()
    optim.step()
    if epoch % show_every == 0 or epoch == num_epochs - 1:
        show_gan_im(sample)

    # To dump frames, uncomment (author-local output directory):
    # plt.imsave(f"/home/amildravid/bigGAN-DINO_swap/morph/ex1/im/{epoch:03d}.png",
    #            ((sample[0].permute(1, 2, 0).detach().cpu().numpy() + 1) / 2).clip(0, 1))

### Activation Loss

In [None]:
# Restart from the same starting w so this experiment is independent of the previous one.
w = w1.detach().clone().requires_grad_(True)
optim = torch.optim.Adam([w], lr=0.01, betas=(0.5, 0.999))

In [None]:
# W-space inversion with the activation loss only: maximize the normalized correlation
# between each best-buddy GAN unit and its DINO unit on the target image.
num_epochs = 1000
show_every = 50  # displaying every step floods the notebook output
num_ws = gan.model.num_ws  # number of per-layer w copies the synthesis network takes
eps = 0.00001
for epoch in range(num_epochs):
    optim.zero_grad()
    sample = gan.model.synthesis(w.unsqueeze(1).repeat(1, num_ws, 1))

    # Normalize every retained GAN layer with its precomputed stats ([0] / [1],
    # presumably mean / std), as done for the DINO targets.
    gan_activs = matching.store_activs(gan, gan_layers)
    gan_activs = [(activ - gan_stats[k][0]) / (gan_stats[k][1] + eps) for k, activ in enumerate(gan_activs)]

    total_corr = 0
    for idx, dino_map in zip(perfect_matches, dino_perfect_activs):
        gan_map = gan_activs[idx[0]][:, idx[1], :, :].unsqueeze(0)  # 4-D, same layout as dino_map
        # Resize both maps to the larger spatial size. Use dim -2 (height): dim 1 of these
        # 4-D maps is the batch axis, not a spatial size.
        map_size = max(gan_map.shape[-2], dino_map.shape[-2])
        gan_up = torch.nn.functional.interpolate(gan_map, size=(map_size, map_size), mode='bilinear')
        dino_up = torch.nn.functional.interpolate(dino_map, size=(map_size, map_size), mode='bilinear')
        prod = torch.einsum('aixy,ajxy->ij', gan_up, dino_up)
        total_corr = total_corr + prod / torch.sqrt(torch.sum(gan_up**2) * torch.sum(dino_up**2))
    loss = -total_corr  # minimize the negative correlation

    #regularization = 50*torch.mean((w-reg)**2)
    #loss +=  regularization
    print("E:", epoch+1, "loss:", loss.item())
    loss.backward()
    optim.step()
    if epoch % show_every == 0 or epoch == num_epochs - 1:
        show_gan_im(sample)

    # To dump frames, uncomment (author-local output directory):
    # plt.imsave(f"/home/amildravid/bigGAN-DINO_swap/morph/ex1/im/{epoch:03d}.png",
    #            ((sample[0].permute(1, 2, 0).detach().cpu().numpy() + 1) / 2).clip(0, 1))
    
    
    
    


      

### Both Losses

In [None]:
# Restart from the same starting w so this experiment is independent of the previous ones.
w = w1.detach().clone().requires_grad_(True)
optim = torch.optim.Adam([w], lr=0.01, betas=(0.5, 0.999))

In [None]:
# W-space inversion with both losses: negative activation correlation plus a weighted
# pixel MSE against the target image.
num_epochs = 1000
show_every = 50  # displaying every step floods the notebook output
pixel_weight = 100  # weight of the pixel MSE relative to the correlation term
num_ws = gan.model.num_ws  # number of per-layer w copies the synthesis network takes
eps = 0.00001
for epoch in range(num_epochs):
    optim.zero_grad()
    sample = gan.model.synthesis(w.unsqueeze(1).repeat(1, num_ws, 1))

    im = (sample + 1) / 2  # generator output mapped to [0, 1], as in show_gan_im
    pixel_loss = torch.mean((im - real_im) ** 2)

    # Normalize every retained GAN layer with its precomputed stats ([0] / [1],
    # presumably mean / std), as done for the DINO targets.
    gan_activs = matching.store_activs(gan, gan_layers)
    gan_activs = [(activ - gan_stats[k][0]) / (gan_stats[k][1] + eps) for k, activ in enumerate(gan_activs)]

    total_corr = 0
    for idx, dino_map in zip(perfect_matches, dino_perfect_activs):
        gan_map = gan_activs[idx[0]][:, idx[1], :, :].unsqueeze(0)  # 4-D, same layout as dino_map
        # Resize both maps to the larger spatial size. Use dim -2 (height): dim 1 of these
        # 4-D maps is the batch axis, not a spatial size.
        map_size = max(gan_map.shape[-2], dino_map.shape[-2])
        gan_up = torch.nn.functional.interpolate(gan_map, size=(map_size, map_size), mode='bilinear')
        dino_up = torch.nn.functional.interpolate(dino_map, size=(map_size, map_size), mode='bilinear')
        prod = torch.einsum('aixy,ajxy->ij', gan_up, dino_up)
        total_corr = total_corr + prod / torch.sqrt(torch.sum(gan_up**2) * torch.sum(dino_up**2))
    activ_loss = -total_corr

    loss = activ_loss + pixel_weight * pixel_loss
    #regularization = 100*torch.mean((w-reg)**2)
    #loss +=  regularization
    print("E:", epoch+1, "pixel:", float(pixel_loss), "activ:", float(activ_loss), "loss:", loss.item())
    loss.backward()
    optim.step()
    if epoch % show_every == 0 or epoch == num_epochs - 1:
        show_gan_im(sample)

    # To dump frames, uncomment (author-local output directory):
    # plt.imsave(f"/home/amildravid/bigGAN-DINO_swap/morph/ex1/im/{epoch:03d}.png",
    #            ((sample[0].permute(1, 2, 0).detach().cpu().numpy() + 1) / 2).clip(0, 1))
    
    
    
    


      

In [None]:
# Render the optimized w, broadcast to every synthesis layer. Use the model's num_ws rather
# than a hard-coded count; display only, so no autograd graph.
with torch.no_grad():
    init_im = (gan.model.synthesis(w.unsqueeze(1).repeat(1, gan.model.num_ws, 1)) + 1) / 2
plt.imshow(torch.permute(init_im[0].detach().cpu(), (1,2,0)).clamp(0, 1))

In [None]:
# Shift the optimized w by 5 along its first coordinate (one-hot direction x).
x = torch.zeros((1, 512), device=device)
x[0, 0] = 1.0
w_new = w + x * 5

In [None]:
# Render the shifted w, broadcast to every synthesis layer. Use the model's num_ws rather
# than a hard-coded count; display only, so no autograd graph.
with torch.no_grad():
    init_im = (gan.model.synthesis(w_new.unsqueeze(1).repeat(1, gan.model.num_ws, 1)) + 1) / 2
plt.imshow(torch.permute(init_im[0].detach().cpu(), (1,2,0)).clamp(0, 1))