In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.transforms.functional import to_tensor, to_pil_image
import fastai.vision.all as fv
import PIL
from pathlib import Path
import random
from Layers import *
from math import prod
from PerceptualLoss import perceptual_loss

In [None]:
class DumbNoiser(nn.Module):
    """Blend every sample of a batch with standard Gaussian noise.

    For each sample a mixing weight ``alpha`` is drawn uniformly from [0, 1)
    and the result is ``(1 - alpha) * x + alpha * noise``, where ``noise`` has
    the same shape as ``x``. The weight tensor has shape ``(bs, 1, 1, 1)``, so
    ``x`` is expected to be a 4-D batch.
    """

    def forward(self, x):
        batch_size = x.size(0)
        # One weight per sample, broadcast over the remaining three dims.
        mix = torch.rand(batch_size, 1, 1, 1, device=x.device)
        gaussian = torch.randn_like(x)
        keep = 1 - mix
        return keep * x + mix * gaussian

In [None]:
class IterativeNoiser(nn.Module):
    """Add the equivalent of a random number of small Gaussian noise steps.

    On each call a step count ``num_iters`` is drawn uniformly from
    ``[0, max_num_steps]`` (inclusive) and every sample gets its own step size
    drawn uniformly from ``[0, scale)``. The output is distributed as ``x``
    plus ``num_iters`` independent ``N(0, step_size**2)`` increments.

    Args:
        max_num_steps: upper bound (inclusive) for the number of noise steps.
        scale: upper bound (exclusive) for the per-sample step size.
    """

    def __init__(self, max_num_steps = 75, scale=0.1):
        super().__init__()
        self.num_steps = max_num_steps
        self.scale = scale

    def forward(self,x):
        """Return a noisy copy of ``x`` (expected 4-D: batch first)."""
        bs = x.shape[0]

        num_iters = random.randint(0, self.num_steps)

        # Per-sample step size, shared by all noise steps of that sample.
        step_size = torch.rand((bs,1,1,1),device=x.device)*self.scale
        # The sum of `num_iters` i.i.d. N(0, 1) draws is N(0, num_iters), so a
        # single draw scaled by sqrt(num_iters) has the same distribution as
        # materialising a (num_iters, *x.shape) tensor and summing it, while
        # using num_iters times less memory. num_iters == 0 leaves x unchanged.
        noise = torch.randn(x.shape, device=x.device)
        return x + noise*step_size*(num_iters**0.5)

In [None]:
class BlurNoiser(nn.Module):
    """Degrade an image batch by downsampling and upsampling it back.

    A factor of 2 or 4 is chosen at random on every call. Both resizes use
    ``F.interpolate`` with its default mode.
    """

    def __init__(self):
        super().__init__()

    def forward(self,x):
        """Return a tensor with the same shape as ``x`` but reduced detail."""
        f = random.choice([2,4])
        ds = F.interpolate(x,scale_factor=1/f)
        # Upsample to the exact input size rather than by `scale_factor=f`:
        # when a spatial dim is not divisible by f the downsample floors it,
        # and scaling back by f would return a smaller tensor than the input
        # (which breaks callers that stack this output with other noisers).
        # For divisible sizes the result is identical to the scale_factor form.
        return F.interpolate(ds,size=x.shape[2:])

In [None]:
class SelectCombinator(nn.Module):
    """Apply several modules to a batch and pick one output per sample.

    Every wrapped module is run on the full batch; then, independently for
    each sample, the output of one module (chosen uniformly at random) is
    returned.

    Args:
        modules: iterable of ``nn.Module`` whose outputs all share one shape.
    """

    def __init__(self, modules):
        super().__init__()
        self.M = nn.ModuleList(modules)

    def forward(self,x):
        """Return a batch where sample ``b`` comes from a random module."""
        results = torch.stack([m(x) for m in self.M])
        n, bs = results.shape[:2]

        # Draw the module index from the actual number of modules; a
        # hard-coded upper bound indexes out of range when fewer modules are
        # wrapped and never selects the extra ones when more are wrapped.
        mask = torch.randint(0, n, (bs,), device=x.device)
        return results[mask,torch.arange(bs, device=x.device)]

In [None]:
class ConvexCombinator(nn.Module):
    """Return a random convex combination of the outputs of several modules.

    All wrapped modules are applied to the same input. One weight per module
    is obtained as the softmax of ``2 * N(0, 1)`` logits, and the weighted sum
    of the outputs is returned. The same weights are used for the whole batch.
    """

    def __init__(self, modules):
        super().__init__()
        self.M = nn.ModuleList(modules)

    def forward(self, x):
        outputs = [module(x) for module in self.M]
        stacked = torch.stack(outputs)
        # Shape (n, 1, 1, 1, 1) broadcasts over the stacked 4-D outputs.
        logits = torch.randn(len(outputs), 1, 1, 1, 1, device=x.device) * 2
        weights = logits.softmax(dim=0)
        return torch.sum(weights * stacked, dim=0)

In [None]:
def normalize_tonto(x):
    """Map values linearly from the range [0, 1] to [-1, 1]."""
    doubled = x * 2
    return doubled - 1
    
def desnormalize_tonto(x):
    """Map values linearly from the range [-1, 1] back to [0, 1]."""
    shifted = x + 1
    return shifted * 0.5

In [None]:
# Visual sanity check of the noisers on one sample image: load it, map it to
# [-1, 1], apply a random convex mix of IterativeNoiser and DumbNoiser, then
# map back and clamp to [0, 1] so it can be shown as a PIL image (cell output).
f = fv.PILImage.create("facesM/25_1_00266.jpg")
x = normalize_tonto(to_tensor(f)[None])
y = ConvexCombinator([IterativeNoiser(),DumbNoiser()])(x)
to_pil_image(torch.clamp(desnormalize_tonto(y[0]),0,1))

In [None]:
def num_params(m):
    """Return the total number of scalar parameters held by module ``m``."""
    total = 0
    for param in m.parameters():
        total += param.numel()
    return total

In [None]:
#x = torch.randn(2,3,16,16)

In [None]:
#U(x)

In [None]:
def get_age(f:Path):
    """Return the age encoded in a file name as a float.

    The age is the text before the first underscore of the file stem,
    e.g. ``25_1_00266.jpg`` -> ``25.0``.
    """
    age_text, _, _ = f.stem.partition("_")
    return float(age_text)

def get_cat(f:Path):
    """Return the category label encoded in a file name, as a string.

    The label is the second underscore-separated field of the file stem,
    e.g. ``25_1_00266.jpg`` -> ``"1"``.
    """
    fields = f.stem.split("_", 2)
    return fields[1]

In [None]:
# Ages of every image under 'facesM', parsed from the file names with get_age.
edades = torch.tensor([get_age(f) for f in fv.get_image_files('facesM')])


In [None]:
# Dataset-wide age statistics; FeaturesProcessor.forward reads these globals
# to standardise the age before feeding it to its MLP.
edades_mean = edades.mean()
edades_std = edades.std()

In [None]:
class Unet(nn.Module):
    """U-shaped network conditioned on age and category.

    The conditioning tensor produced by ``fea_proc(edad, cat)`` is concatenated
    to the input along the channel dimension. The encoder blocks are applied in
    order and every intermediate result (including the concatenated input) is
    kept. Each decoder block then receives the current activation plus the
    matching stored result, taken in reverse order.

    Args:
        encoder_blocks: modules applied sequentially on the way down.
        decoder_blocks: modules applied sequentially on the way up.
        fea_proc: callable ``(edad, cat) -> tensor`` that can be concatenated
            with ``x`` on dim 1.
    """

    def __init__(self, encoder_blocks, decoder_blocks, fea_proc):
        super().__init__()
        self.EB = nn.ModuleList(encoder_blocks)
        self.DB = nn.ModuleList(decoder_blocks)
        self.fea_proc = fea_proc

    def forward(self, x, edad, cat):
        conditioning = self.fea_proc(edad, cat)
        x = torch.cat((x, conditioning), dim=1)

        skips = [x]
        for encoder in self.EB:
            x = encoder(x)
            skips.append(x)

        for decoder, skip in zip(self.DB, reversed(skips)):
            batch, channels, height, width = x.shape
            missing = channels - skip.shape[1]
            if missing > 0:
                # The stored result has fewer channels than the current
                # activation: zero-pad it so the two can be added.
                padding = torch.zeros(batch, missing, height, width, device=x.device)
                skip = torch.cat([skip, padding], dim=1)
            x = decoder(x + skip)
        return x

In [None]:
class Noisy_Unet(nn.Module):
    """Wrap an encoder and a Unet with noise applied before and after encoding.

    ``forward`` corrupts the input with ``PrewarNoiser`` (a random convex mix
    of DumbNoiser, IterativeNoiser(30, 0.08) and BlurNoiser), encodes the
    result, corrupts the encoding with ``Noiser`` (a per-sample choice between
    DumbNoiser and IterativeNoiser), and passes it to the Unet together with
    the age and category.

    Args:
        encoder: module mapping the input batch to the tensor the Unet consumes.
        Unet: module called as ``Unet(latent, edad, cat)``.
    """

    def __init__ (self,encoder,Unet):
        super().__init__()
        self.encoder = encoder
        self.Unet = Unet
        latent_noisers = [DumbNoiser(), IterativeNoiser()]
        self.Noiser = SelectCombinator(latent_noisers)
        input_noisers = [DumbNoiser(), IterativeNoiser(30,0.08), BlurNoiser()]
        self.PrewarNoiser = ConvexCombinator(input_noisers)

    def forward(self, x, edad, cat):
        corrupted = self.PrewarNoiser(x)
        latent = self.encoder(corrupted)
        noisy_latent = self.Noiser(latent)
        return self.Unet(noisy_latent, edad, cat)

In [None]:
class FeaturesProcessor(nn.Module):
    """Turn an (age, category) pair into a 2-channel conditioning map.

    The age is standardised with the module-level ``edades_mean`` and
    ``edades_std``; the category index goes through a 2-entry, 3-dim embedding.
    The resulting 4 features are fed to a small MLP whose 32 outputs are
    reshaped to ``(bs, 2, 4, 4)`` and bilinearly upsampled by 8.
    """

    def __init__(self):
        super().__init__()
        self.Emb = nn.Embedding(num_embeddings=2, embedding_dim=3)
        mlp_layers = [
            nn.Linear(4, 64),
            nn.LeakyReLU(),
            nn.BatchNorm1d(64),
            nn.Linear(64, 2*4*4),
        ]
        self.Nn = nn.Sequential(*mlp_layers)

    def forward(self, edad, cat):
        edad_norm = (edad - edades_mean) / edades_std
        cat_vec = self.Emb(cat)
        # 1 standardised age column + 3 embedding columns = 4 input features.
        features = torch.cat((edad_norm.unsqueeze(1), cat_vec), dim=1)
        grid = self.Nn(features).reshape(edad.shape[0], 2, 4, 4)
        return F.interpolate(grid, scale_factor=8, mode='bilinear')

In [None]:
# Encoder stages for Unet. ResBlock, SelfAttention and cab come from Layers
# (not shown here). Channel widths per stage, as given by the arguments:
# 5 -> 64 -> 128, 128 -> 256, 256 -> 384. The 5 input channels match the Unet
# input after the 2-channel FeaturesProcessor output is concatenated.
# NOTE(review): cab(..., s=2, k=2) presumably halves the spatial resolution
# (stride 2, kernel 2) — confirm against Layers.
encoder_blocks = [
    nn.Sequential(
        ResBlock(5,64),
        ResBlock(64),
        nn.BatchNorm2d(64),
        SelfAttention(64),
        ResBlock(64),
        ResBlock(64),
        nn.BatchNorm2d(64),
        *cab(64,128, s=2, k=2)
    ),
    nn.Sequential(
        ResBlock(128),
        ResBlock(128),
        nn.BatchNorm2d(128),
        SelfAttention(128),
        ResBlock(128),
        ResBlock(128),
        nn.BatchNorm2d(128),
        *cab(128,256, s=2, k=2)
    ),
    nn.Sequential(
        ResBlock(256),
        ResBlock(256,),
        nn.BatchNorm2d(256),
        SelfAttention(256),
        # NOTE(review): only this block of the stage passes g=2 — confirm intended.
        ResBlock(256,g=2),
        ResBlock(256),
        nn.BatchNorm2d(256),
        *cab(256,384, s=2, k=2)
    )
]

In [None]:
# Decoder stages for Unet, mirroring the encoder widths in reverse:
# 384 -> 256 -> 128 -> 64 -> 3. The first three stages end in
# fv.PixelShuffle_ICNR and the last one in conv2d(64,3) (from Layers).
# There is one more decoder stage than encoder stages because Unet.forward
# also keeps the (concatenated) input as a skip tensor.
decoder_blocks = [
    nn.Sequential(
        ResBlock(384,g=2),
        ResBlock(384,g=2),
        nn.BatchNorm2d(384),
        SelfAttention(384),
        ResBlock(384,g=2),
        ResBlock(384,g=2),
        fv.PixelShuffle_ICNR(384,256)
    ),
    nn.Sequential(
        ResBlock(256),
        ResBlock(256),
        nn.BatchNorm2d(256),
        SelfAttention(256),
        ResBlock(256),
        ResBlock(256),
        fv.PixelShuffle_ICNR(256,128)
    ),
    nn.Sequential(
        ResBlock(128),
        ResBlock(128),
        nn.BatchNorm2d(128),
        SelfAttention(128),
        ResBlock(128),
        ResBlock(128),
        fv.PixelShuffle_ICNR(128,64)
    ),
    nn.Sequential(
        ResBlock(64),
        ResBlock(64),
        nn.BatchNorm2d(64),
        SelfAttention(64),
        ResBlock(64),
        ResBlock(64),
        conv2d(64,3)
    )
]

In [None]:
# Assemble the conditioned U-Net from the stage lists defined above.
U = Unet(encoder_blocks, decoder_blocks, FeaturesProcessor())

In [None]:
# Total number of parameters in the U-Net (shown as the cell output).
num_params(U)

In [None]:
def load_data(folder, img_size, batch_size):
    """Build fastai DataLoaders for the conditioned image-to-image task.

    Each image file yields three inputs (the image, its age from ``get_age``
    and its category from ``get_cat``) and the same image as target. 5% of the
    files are held out for validation with a fixed split seed. Images are
    resized to ``img_size`` per item and fastai's default ``aug_transforms``
    are applied per batch.

    Args:
        folder: directory scanned with ``fv.get_image_files``.
        img_size: size passed to ``fv.Resize``.
        batch_size: batch size of the returned DataLoaders.
    """
    def same_path(path):
        # The image blocks load the picture directly from the file path.
        return path

    block = fv.DataBlock(
        blocks=(fv.ImageBlock, fv.RegressionBlock, fv.CategoryBlock, fv.ImageBlock),
        n_inp=3,
        get_items=fv.get_image_files,
        getters=[same_path, get_age, get_cat, same_path],
        splitter=fv.RandomSplitter(.05, seed=666),
        item_tfms=fv.Resize(img_size),
        batch_tfms=fv.aug_transforms(),
    )
    return block.dataloaders(folder, bs=batch_size)

In [None]:
# DataLoaders over 'facesM' with 128-pixel images and batches of 128.
dls = load_data("facesM", 128, 128)

In [None]:
from autoencoder import create_autoencoder

In [None]:
# Autoencoder whose encoder/decoder are used below; kept in eval mode.
A = create_autoencoder().eval()

In [None]:
# Weights of the autoencoder, stored under the 'model' key of the checkpoint.
# NOTE(review): torch.load unpickles the file and no map_location is given —
# only load trusted checkpoints, and confirm this works on CPU-only machines.
autoencoder_dict = torch.load('models/Perceptual.pth')['model']

In [None]:
# Copy the checkpoint weights into the autoencoder.
A.load_state_dict(autoencoder_dict)

In [None]:
# Freeze the autoencoder: none of its parameters should receive gradients.
# The trailing semicolon suppresses the module repr as cell output.
A.requires_grad_(False);

In [None]:
#x = torch.randn(1,3,32,32)
#A.eval().cpu()
#img_de_latente_aleatorio = A.decoder(x)

#to_pil_image(img_de_latente_aleatorio[0].clamp(0,1))

In [None]:
#f = fv.PILImage.create("facesM/30_1_28985.jpg")

In [None]:
#A.encoder(to_tensor(f)[None]).std()

In [None]:
# Move the autoencoder to the GPU (the semicolon hides the cell output).
A.cuda();

In [None]:
# Training model: noise -> frozen autoencoder encoder -> noise -> U-Net.
model = Noisy_Unet(A.encoder,U)

In [None]:
def mse_latente(yp, y):
    """Smooth-L1 loss between a prediction and the encoding of the target.

    Despite the name this is not an MSE: the target ``y`` is passed through
    ``A.encoder`` and compared with ``yp`` using ``F.smooth_l1_loss``.
    """
    target_latent = A.encoder(y)
    return F.smooth_l1_loss(yp, target_latent)

In [None]:
# Learner for the first training phase (latent-space loss only).
# Uses the ranger optimizer with weight decay 0.05, also applied to batch-norm
# and bias parameters (wd_bn_bias=True). Callbacks: gradient clipping at 0.2
# and checkpointing under the name 'UNET_it'.
learn = fv.Learner(dls,model,
                   loss_func = mse_latente,
                   opt_func = fv.ranger,
                   wd=0.05,
                   wd_bn_bias = True,
                   cbs=[fv.GradientClip(0.2), fv.SaveModelCallback(fname='UNET_it')]
                  )

In [None]:
#learn.load("UNET_perceptual_finished")

In [None]:
# Run fastai's learning-rate finder before choosing lr_max.
learn.lr_find()

In [None]:
# Phase 1: 40 epochs of one-cycle training with lr_max=1e-2.
# NOTE(review): div=0.95 is below 1, so the starting lr is presumably slightly
# ABOVE lr_max (lr_max/div) — confirm this is intended.
learn.fit_one_cycle(40,1e-2,div=0.95,pct_start=0.6)

In [None]:
# Checkpoint the model at the end of phase 1.
learn.save("UNET_it_finished")

In [None]:
# Perceptual loss module (from PerceptualLoss), moved to the GPU.
ploss = perceptual_loss().cuda()
def mse_latente_perceptual(yp_l, y):
    """Combined latent and perceptual loss for the second training phase.

    ``yp_l`` is a predicted latent and ``y`` the target image. The result is
    16 times the smooth-L1 distance between ``yp_l`` and ``A.encoder(y)``,
    plus ``ploss`` between the decoded prediction ``A.decoder(yp_l)`` and ``y``.
    """
    target_latent = A.encoder(y)
    decoded_pred = A.decoder(yp_l)
    latent_term = 16*F.smooth_l1_loss(yp_l, target_latent)
    perceptual_term = ploss(decoded_pred, y)
    return latent_term + perceptual_term

In [None]:
# Phase 2: switch to the combined latent + perceptual loss and lower the batch
# size to 16. The callbacks given when the Learner was built stay attached.
# NOTE(review): confirm that assigning .bs on existing DataLoaders actually
# changes the batch size in the installed fastai version.
learn.loss_func = mse_latente_perceptual
learn.dls.train.bs = 16
learn.dls.valid.bs = 16

In [None]:
# Phase 2: 30 more one-cycle epochs at a lower lr_max (7e-4).
learn.fit_one_cycle(30,7e-4,div=0.95,pct_start=0.6)

In [None]:
# Checkpoint the model at the end of phase 2.
learn.save("UNET_finished")

In [None]:
#learn.load("UNET_it_perceptual")

In [None]:
# NOTE(review): `alumnos` is not defined in any visible cell, so this raises
# NameError on a fresh Restart & Run All. It looks unrelated to the model
# pipeline — confirm, then remove the cell or define `alumnos`.
random.choice(alumnos)

In [None]:
from PIL import Image
from IPython.display import HTML, display
import io
import base64

def display_gif(image_list, duration=100, loop=0):
    """Render a list of PIL images inline as an animated GIF.

    The frames are encoded to a GIF in memory and embedded in the notebook
    output as a base64 ``<img>`` tag.

    Args:
        image_list: non-empty list of PIL images; the first one is the lead frame.
        duration: value forwarded to PIL's GIF writer as ``duration``.
        loop: value forwarded to PIL's GIF writer as ``loop``.
    """
    first_frame, remaining_frames = image_list[0], image_list[1:]

    with io.BytesIO() as stream:
        first_frame.save(stream, format='GIF', save_all=True, append_images=remaining_frames, duration=duration, loop=loop)
        gif_bytes = stream.getvalue()

    gif_data = base64.b64encode(gif_bytes).decode('ascii')
    display(HTML(f'<img src="data:image/gif;base64,{gif_data}">'))

In [None]:
# Generate a face by iteratively refining a random latent with the U-Net,
# storing the decoded image of every step so the process can be animated.
U.eval().cpu()
A.eval().cpu()
x = torch.randn(1,3,32,32)
edad = torch.tensor([4])
cat = torch.tensor([1])
steps = 20
# Inference only: without no_grad the autograd graph would chain through every
# refinement step (U's parameters require grad), wasting memory and time.
with torch.no_grad():
    # Clamp the first frame too, so it is treated like all later frames.
    images = [to_pil_image(torch.clamp(A.decoder(x)[0],0,1))]
    for i in range(steps):
        p = U(x,edad,cat)
        faltan = steps - i
        # Move x a fraction 1/faltan of the way towards the prediction; on the
        # last step (faltan == 1) x becomes the prediction itself.
        x = p/faltan + x*(faltan - 1)/faltan
        img = torch.clamp(A.decoder(x)[0],0,1)
        images.append(to_pil_image(img))

In [None]:
# Animate the frames collected by the sampling loop.
display_gif(images)

## Cambios que hice:

- Agregué capas de self attention a la UNET
- Creé el `IterativeNoiser`, que en vez de agregar ruido una sola vez lo hace varias veces (agregando poquito ruido en cada paso), porque así funciona Stable Diffusion, a diferencia de como lo habíamos hecho antes.
- Combiné ambos ruidos. En cada imagen toma uno de los dos ruidos aleatoriamente (yo siento que tener más tipos de ruido lo hace mejor)
- Agregué los callbacks `SaveModelCallback` y `GradientClip`, además de weight decay en el `Learner`.
- Entrené 100 epochs. YOLO.