In [None]:
import torchvision.transforms as T
import torch,torchvision
from torch import nn,optim
import torch.nn.functional as F
from torch.utils.data import DataLoader
# Single device string used for every model and tensor in the notebook.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
# import torchvision.transforms.functional as VF

In [None]:
batch_size = 64 * 2


def _image_loader(dataset_cls, channels, batch_size):
    """Build a shuffled DataLoader over a torchvision image dataset.

    Images are converted with ``ToTensor`` (pixel values in [0, 1]) and then
    normalised per channel with mean=0.5, std=0.5, i.e. mapped to [-1, 1].

    Args:
        dataset_cls: torchvision dataset class, e.g. ``torchvision.datasets.MNIST``.
        channels: number of image channels, used to size the Normalize tuples.
        batch_size: samples per batch.
    """
    half = (.5,) * channels  # an explicit tuple per channel, not a bare float
    transform = T.Compose([
        T.ToTensor(),
        T.Normalize(half, half),
    ])
    return DataLoader(
        dataset_cls("./datasets", download=True, transform=transform),
        batch_size=batch_size,
        num_workers=4,
        pin_memory=(device == "cuda"),
        shuffle=True,
    )


cifar10 = _image_loader(torchvision.datasets.CIFAR10, channels=3, batch_size=batch_size)
mnist = _image_loader(torchvision.datasets.MNIST, channels=1, batch_size=batch_size)
len(cifar10)

### Model forward passes and losses

In [None]:
def gen_tick(dis, g):
    """Generator loss: binary cross-entropy between D(g) and the "real" label 1.

    Args:
        dis: discriminator whose output is already a probability in [0, 1]
            (``F.binary_cross_entropy`` requires this).
        g: batch of generated samples, still attached to the generator's
            graph so the loss can backpropagate into the generator.

    Returns:
        Scalar loss tensor (BCE with its default mean reduction).
    """
    v = dis(g)
    # ones_like allocates on v's device and dtype directly instead of building
    # a CPU float32 tensor and copying it over on every step.
    return F.binary_cross_entropy(v, torch.ones_like(v))
def disc_tick(dis, x, g):
    """Discriminator loss: mean of the BCE on real (label 1) and fake (label 0) batches.

    Args:
        dis: discriminator whose output is already a probability in [0, 1].
        x: batch of real samples.
        g: batch of generated samples (callers pass it detached so this loss
            does not reach the generator).

    Returns:
        Scalar loss tensor, ``(loss_real + loss_fake) / 2``.
    """
    real_pred = dis(x)
    fake_pred = dis(g)
    # *_like labels inherit each prediction's device/dtype: no CPU->device copy.
    real_loss = F.binary_cross_entropy(real_pred, torch.ones_like(real_pred))
    fake_loss = F.binary_cross_entropy(fake_pred, torch.zeros_like(fake_pred))
    return (real_loss + fake_loss) / 2

### Helpers

In [None]:
import tqdm
import matplotlib.pyplot as plt
import numpy as np
from pathlib import Path
from dataclasses import dataclass
@dataclass
class Case:
    """Everything needed to train one GAN variant; built by ``setup``."""
    gen: nn.Module                     # generator network
    dis: nn.Module                     # discriminator network
    mean: float                        # mean of the Gaussian latent noise
    std: float                         # std of the Gaussian latent noise
    gen_optim: torch.optim.Optimizer   # optimizer over gen's parameters
    dis_optim: torch.optim.Optimizer   # optimizer over dis's parameters
    gen_input_size: tuple              # per-sample latent shape, e.g. (100, 1, 1)
    gen_imput_dim: int                 # latent size; misspelled name kept for compatibility

    @property
    def gen_input_dim(self) -> int:
        """Correctly spelled, read-only alias of ``gen_imput_dim``."""
        return self.gen_imput_dim
_DEFAULT_ACT = object()  # sentinel meaning "a fresh nn.ReLU(inplace=True)"


def Conv(inc, outc, ks, act=_DEFAULT_ACT, transposed=False, norm=True, **kwargs):
    """Build conv (or transposed conv) -> optional BatchNorm2d -> optional activation.

    Args:
        inc: input channels.
        outc: output channels.
        ks: kernel size.
        act: activation module appended last. When omitted, a *new*
            ``nn.ReLU(inplace=True)`` is created for this call. Pass ``None``
            (or any falsy value) to omit the activation.
        transposed: use ``nn.ConvTranspose2d`` instead of ``nn.Conv2d``.
        norm: append ``nn.BatchNorm2d(outc)`` after the convolution.
        **kwargs: forwarded to the conv constructor (``stride``, ``padding``, ...).
            ``bias`` defaults to ``False`` but may be overridden.

    Returns:
        ``nn.Sequential`` of the selected layers.
    """
    if act is _DEFAULT_ACT:
        # A module used directly as a default argument is created once and
        # then shared by every layer that relies on the default.
        act = nn.ReLU(inplace=True)
    conv_cls = nn.ConvTranspose2d if transposed else nn.Conv2d
    # A conv bias is redundant before BatchNorm; keep "no bias" as the default
    # but no longer reject an explicit bias= from the caller.
    kwargs.setdefault("bias", False)
    layers = [conv_cls(inc, outc, ks, **kwargs)]
    if norm:
        layers.append(nn.BatchNorm2d(outc))
    if act:
        layers.append(act)
    return nn.Sequential(*layers)


def setup(version, input_dim, input_size, *args):
    """Instantiate a GAN variant and bundle it into a ``Case``.

    ``version`` is a factory called as ``version(input_dim, *args)``; it must
    return ``(gen, dis, gen_optim, dis_optim, mean, std)``.
    ``input_size`` is the per-sample latent shape stored on the Case.
    """
    built = version(input_dim, *args)
    gen, dis, gen_optim, dis_optim, m, std = built
    fields = dict(
        gen=gen,
        dis=dis,
        mean=m,
        std=std,
        gen_optim=gen_optim,
        dis_optim=dis_optim,
        gen_imput_dim=input_dim,
        gen_input_size=input_size,
    )
    return Case(**fields)
def grid(images, nrows):
    """Tile a batch of images into a single CPU grid tensor, ``nrows`` images per row.

    NOTE: ``make_grid`` only applies ``value_range`` when ``normalize=True``,
    so the argument has no effect here and pixel values come back unchanged;
    any rescaling for display has to happen in the caller.
    """
    cpu_batch = images.detach().cpu()
    return torchvision.utils.make_grid(
        cpu_batch,
        nrow=nrows,
        value_range=[-1, 1],
    )
def test(gen, noise, res):
    """Render ``gen(noise)`` as one image grid and append it to ``res``.

    The generator runs in eval mode under ``torch.inference_mode``. Its
    previous train/eval mode is restored afterwards, so calling this helper
    no longer silently leaves the model in eval mode.

    Args:
        gen: generator module.
        noise: latent batch; its first dimension is the number of images,
            all laid out in a single row.
        res: list that receives the grid tensor (mutated in place).
    """
    was_training = gen.training
    gen.eval()
    try:
        with torch.inference_mode():
            images = gen(noise)
    finally:
        gen.train(was_training)
    res.append(grid(images, noise.size(0)))
def sample(size, generator=None, m=0.0, std=1.0):
    """Return a tensor of shape ``size`` drawn from N(m, std**2) on ``device``.

    ``generator`` is an optional ``torch.Generator`` for reproducible draws;
    it must live on the same device as the output.
    """
    noise = torch.normal(m, std, size=size, generator=generator, device=device)
    return noise
def show(imgs, **kwargs):
    """Plot one image tensor, or a list of them, stacked vertically with no ticks.

    Args:
        imgs: an image tensor or a list of them. Float tensors are expected in
            [0, 1], since ``to_pil_image`` scales them by 255. An empty list
            draws nothing.
        **kwargs: forwarded to ``plt.subplots`` (e.g. ``figsize``).
    """
    # Imported once, under its own alias, so that inside this function ``F``
    # is not rebound away from torch.nn.functional.
    import torchvision.transforms.functional as VF

    if not isinstance(imgs, list):
        imgs = [imgs]
    if not imgs:
        # plt.subplots(nrows=0) raises, e.g. when training was interrupted
        # before the first snapshot was taken.
        return
    _, axs = plt.subplots(ncols=1, nrows=len(imgs), squeeze=True, **kwargs)
    # With a single row squeeze=True yields a bare Axes; normalise to an array.
    for img, ax in zip(imgs, np.atleast_1d(axs)):
        ax.imshow(np.asarray(VF.to_pil_image(img)))
        ax.set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])
def _train_step(case, X):
    """Run one discriminator update, then one generator update, on real batch ``X``.

    Returns:
        ``(disc_loss, gen_loss)`` as Python floats.
    """
    z = sample((X.size(0), *case.gen_input_size), m=case.mean, std=case.std)

    case.dis_optim.zero_grad()
    g = case.gen(z)
    # Detached so the discriminator loss does not backpropagate into the generator.
    disc_loss = disc_tick(case.dis, X, g.detach())
    disc_loss.backward()
    case.dis_optim.step()

    # The same fake batch is reused for the generator update; its graph is
    # still intact because only the detached copy went through backward().
    case.gen_optim.zero_grad()
    gen_loss = gen_tick(case.dis, g)
    gen_loss.backward()
    case.gen_optim.step()
    return disc_loss.item(), gen_loss.item()


def train(
    case:Case,
    dataloader,
    num_epochs,
    test_noise=None,
    test_every_n_epochs=0,
    test_res=None,
):
    """Alternate discriminator/generator updates for ``num_epochs`` epochs.

    Args:
        case: models, optimizers and latent-noise parameters to train.
        dataloader: yields ``(images, labels)`` batches; labels are ignored.
        num_epochs: number of passes over ``dataloader``.
        test_noise: fixed latent batch used for progress snapshots; ``None``
            disables snapshots.
        test_every_n_epochs: take a snapshot after every epoch ``e`` with
            ``e % test_every_n_epochs == 0``. A value <= 0 disables snapshots
            (the old default of 0 raised ZeroDivisionError when noise was given).
        test_res: list to append snapshot grids to. When omitted, a new list is
            created per call (the old ``[]`` default was shared between calls).

    Returns:
        ``test_res`` with any new snapshots appended. Ctrl-C (KeyboardInterrupt)
        stops training early and still returns the snapshots collected so far.
    """
    if test_res is None:
        test_res = []
    snapshots_enabled = test_noise is not None and test_every_n_epochs > 0
    case.gen.train()
    case.dis.train()
    try:
        for e in range(num_epochs):
            # Context manager closes the progress bar even on interrupt.
            with tqdm.tqdm(dataloader) as bar:
                for X, _ in bar:
                    disc_loss, gen_loss = _train_step(case, X.to(device))
                    bar.set_description(
                        f"Epoch {e}:  --  Disc:{disc_loss:.4f}, Gen:{gen_loss:.4f}")
            if snapshots_enabled and e % test_every_n_epochs == 0:
                test(case.gen, test_noise, test_res)
                case.gen.train()
                case.dis.train()
    except KeyboardInterrupt:
        # Deliberate: lets the user stop a long run from the notebook and keep
        # the snapshots gathered so far.
        pass
    return test_res
def save(path, case: "Case"):
    """Write generator/discriminator weights and both optimizer states to ``path``.

    The checkpoint dict has keys ``gen``, ``dis``, ``gen_optim`` and
    ``dis_optim``, which are the keys ``load`` reads. Parent directories are
    created as needed.
    """
    Path(path).parent.mkdir(parents=True, exist_ok=True)
    state = {
        "gen": case.gen.state_dict(),
        "dis": case.dis.state_dict(),  # previously misspelled "dist"
        "gen_optim": case.gen_optim.state_dict(),
        # Previously a second "gen_optim" key, which silently replaced the
        # generator optimizer's state with the discriminator's.
        "dis_optim": case.dis_optim.state_dict(),
    }
    torch.save(state, path)
def load(path, case: "Case"):
    """Restore weights and optimizer states from a checkpoint into ``case``.

    Tensors are first loaded onto the CPU. ``load_state_dict`` then copies
    them onto whatever device the models and optimizers already use, so a
    checkpoint written on a GPU can be restored on a CPU-only machine.

    Checkpoints written by the earlier, buggy ``save`` stored the
    discriminator weights under ``"dist"``. Under ``"gen_optim"`` they held
    the discriminator optimizer state, because a duplicate key overwrote it.
    Those files are still accepted: the discriminator weights and optimizer
    are restored, and the generator optimizer is left untouched with a warning.
    """
    import warnings

    state = torch.load(path, map_location="cpu")
    case.gen.load_state_dict(state["gen"])
    if "dis" not in state and "dist" in state:
        case.dis.load_state_dict(state["dist"])
        case.dis_optim.load_state_dict(state["gen_optim"])
        warnings.warn(
            f"{path}: legacy checkpoint without generator optimizer state; "
            "gen_optim was not restored."
        )
        return
    case.dis.load_state_dict(state["dis"])
    case.gen_optim.load_state_dict(state["gen_optim"])
    case.dis_optim.load_state_dict(state["dis_optim"])

#### GAN

In [None]:
def simple(gen_in, h=28, w=28):
    """Small GAN for single-channel ``h`` x ``w`` images (defaults: MNIST's 28x28).

    Returns ``(gen, dis, gen_optim, dis_optim, mean, std)``, the tuple ``setup``
    unpacks. ``mean``/``std`` parameterise the Gaussian latent noise.

    NOTE(review): the generator's unpadded 4x4 convs and two stride-2 convs
    shrink the spatial size, so the latent noise must be spatially larger
    than ``h`` x ``w`` for the output to come out ``h`` x ``w``. Confirm the
    intended ``input_size``.
    """
    lr = 0.0002
    m, std = 0, 1
    gen = nn.Sequential(
        Conv(gen_in, 64, 4),
        Conv(64, 64, 4),
        Conv(64, 64, 4, stride=2, padding=1),
        Conv(64, 64, 4, stride=2, padding=1),
        nn.Conv2d(64, 1, (3, 3), padding=1),
        # The data loaders normalise images to [-1, 1]; a Sigmoid output is
        # confined to [0, 1] and could never match the real data.
        nn.Tanh(),
        nn.Flatten(1, 2),  # (B, 1, H, W) -> (B, H, W)
    ).to(device)
    dis = nn.Sequential(
        nn.Flatten(),
        nn.Linear(h * w, 100), nn.ReLU(),
        nn.Linear(100, 1),
        # F.binary_cross_entropy in the loss functions requires probabilities.
        nn.Sigmoid(),
    ).to(device)
    gen_optim = optim.Adam(gen.parameters(), lr=lr, betas=(0.5, 0.999))
    dis_optim = optim.Adam(dis.parameters(), lr=lr, betas=(0.5, 0.999))
    return gen, dis, gen_optim, dis_optim, m, std

#### DCGAN

In [None]:
def dcgan(gen_in, channels=3):
    """DCGAN generator/discriminator pair (Radford et al., https://arxiv.org/pdf/1511.06434.pdf).

    Generator: a ``(gen_in, 1, 1)`` latent goes 1x1 -> 4x4 -> 8x8 -> 16x16 -> 32x32
    and comes out as a ``channels``-channel image with Tanh output in [-1, 1].
    Discriminator: for 32x32 inputs, each sample maps to one probability,
    giving a batch output of shape (B, 1).

    Returns:
        ``(gen, dis, gen_optim, dis_optim, mean, std)``; ``mean``/``std``
        parameterise the Gaussian latent noise.
    """
    def init(module):
        # Conv weights ~ N(0, 0.02). BatchNorm scale ~ N(1, 0.02) with zero
        # shift, as in the reference DCGAN code. Drawing the scale around 0
        # (as before) starts every BN channel multiplied by ~0.02 with a
        # random sign.
        if isinstance(module, (nn.Conv2d, nn.ConvTranspose2d)):
            nn.init.normal_(module.weight.data, 0.0, 0.02)
        elif isinstance(module, nn.BatchNorm2d):
            nn.init.normal_(module.weight.data, 1.0, 0.02)
            nn.init.zeros_(module.bias.data)

    m, std = 0, 1  # latent noise ~ N(m, std**2)
    lr = 0.0002
    gen = nn.Sequential(
        Conv(gen_in, 1024, 4, transposed=True),                       # 1x1   -> 4x4
        Conv(1024, 512, 4, stride=2, padding=1, transposed=True),     # 4x4   -> 8x8
        Conv(512, 256, 4, stride=2, padding=1, transposed=True),      # 8x8   -> 16x16
        Conv(256, channels, 4, stride=2, padding=1,
             act=nn.Tanh(), transposed=True, norm=False),             # 16x16 -> 32x32
    ).apply(init).to(device)
    dis = nn.Sequential(
        Conv(channels, 128, 4, stride=2, padding=1, norm=False, act=nn.LeakyReLU(.2)),  # 32 -> 16
        Conv(128, 64, 4, stride=2, padding=1, act=nn.LeakyReLU(.2)),                   # 16 -> 8
        Conv(64, 32, 4, stride=2, padding=1, act=nn.LeakyReLU(.2)),                    # 8 -> 4
        Conv(32, 1, 3, stride=2, padding=0, act=None, norm=False),                     # 4 -> 1
        nn.Flatten(),
        # Probabilities, as required by F.binary_cross_entropy in the losses.
        nn.Sigmoid(),
    ).apply(init).to(device)
    gen_optim = optim.Adam(gen.parameters(), lr=lr, betas=(0.5, 0.999))
    dis_optim = optim.Adam(dis.parameters(), lr=lr, betas=(0.5, 0.999))
    return gen, dis, gen_optim, dis_optim, m, std

In [None]:
gen_input_dim = 100
input_size = (gen_input_dim, 1, 1)  # latent noise is a gen_input_dim-channel 1x1 "image"
dcgan_case = setup(dcgan, gen_input_dim, input_size)

# A seeded generator gives the same fixed noise batch every time this cell
# runs, so snapshots taken at different epochs are directly comparable.
generator = torch.Generator(device).manual_seed(1337)
noise = sample((10, *input_size), generator=generator,
               m=dcgan_case.mean, std=dcgan_case.std)
test_res = []

In [None]:
test_res = train(
    dcgan_case,
    cifar10,
    num_epochs=20,
    test_noise=noise,
    test_every_n_epochs=2,
    test_res=test_res,
)
# The DCGAN generator ends in Tanh ([-1, 1]); rescale to [0, 1] for display.
latest_snapshots = [(snapshot + 1) / 2 for snapshot in test_res[-10:]]
show(latest_snapshots, figsize=(13, 8))

In [None]:
checkpoint_path = "./models/dcgan_tanh_v2.pt"
save(checkpoint_path, dcgan_case)
# TODO add soft interpolation

### CGAN (conditional GAN, work in progress)

In [None]:
def cgan(gen_in, ):
    """Work-in-progress conditional GAN assembled from DCGAN-style layers.

    NOTE(review): not runnable as written:
      * ``gen`` is never assigned in this function, so ``gen.parameters()``
        below raises NameError (or silently picks up an unrelated global
        ``gen``);
      * ``Gen.forward`` computes a label embedding but neither uses it nor
        returns anything (it returns None);
      * ``Dis`` is an empty placeholder; the discriminator actually built is
        an unconditional ``nn.Sequential`` that ignores labels.
    The return tuple ``(gen, dis, gen_optim, dis_optim, m, std)`` follows the
    layout ``setup`` unpacks.
    """
    def init(m):
        # Conv / transposed-conv weights ~ N(0, 0.02).
        if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
            nn.init.normal_(m.weight.data, 0, 0.02)
        elif isinstance(m, nn.BatchNorm2d):
            # NOTE(review): BatchNorm scale is drawn around 0 here; reference
            # DCGAN code draws it around 1 with zero shift. Confirm intent.
            nn.init.normal_(m.weight.data, 0, 0.02)
            nn.init.normal_(m.bias.data, 0, 0.02)
    # https://arxiv.org/pdf/1511.06434.pdf
    # Mean/std of the Gaussian latent noise, and the shared Adam learning rate.
    m, std = 0, 1
    lr = 0.0002
    class Gen(nn.Module):
        """Label-conditioned generator (unfinished)."""
        def __init__(self, n_classes) -> None:
            super().__init__()
            # Transposed-conv stack: gen_in input channels -> 3-channel Tanh image.
            # TODO(review): the input channel count does not yet account for
            # the label embedding.
            self.aa = nn.Sequential(
                Conv(gen_in, 1024, 4, transposed=True),
                Conv(1024, 512, 4, stride=2, padding=1, transposed=True),
                Conv(512, 256, 4, stride=2, padding=1, transposed=True),
                Conv(256, 3, 4, stride=2, padding=1,
                    act=nn.Tanh(), transposed=True, norm=False),
            ).apply(init).to(device)
            # Maps each of n_classes labels to a 10-dimensional vector.
            self.embedding = nn.Embedding(n_classes,10)
        def forward(self,batch):
            # presumably batch is (noise, labels) — TODO confirm.
            # TODO: combine ``emb`` with ``X``, run ``self.aa`` and return the images.
            X,y = batch
            emb = self.embedding(y)
    class Dis(nn.Module):
        # Placeholder for a label-conditioned discriminator; currently unused.
        pass
    # Unconditional discriminator (same layout as the DCGAN one; ignores labels).
    dis = nn.Sequential(
        Conv(3, 128, 4, stride=2, padding=1, norm=False, act=nn.LeakyReLU(.2)),
        Conv(128, 64, 4, stride=2, padding=1, act=nn.LeakyReLU(.2)),
        Conv(64, 32, 4, stride=2, padding=1, act=nn.LeakyReLU(.2)),
        Conv(32, 1, 3, stride=2, padding=0, act=None, norm=False),
        # Conv(128//8, 1, 4, stride=2, padding=1, act=nn.LeakyReLU(.2)),
        # nn.AdaptiveAvgPool2d(1),
        nn.Flatten(),
        nn.Sigmoid(),
    ).apply(init).to(device)
    # NOTE(review): ``gen`` is undefined at this point (see docstring).
    gen_optim = optim.Adam(gen.parameters(), lr=lr, betas=(0.5, 0.999))
    dis_optim = optim.Adam(dis.parameters(), lr=lr, betas=(0.5, 0.999))
    return gen, dis, gen_optim, dis_optim , m, std