## Clone repository

In [None]:
!git clone https://github.com/mukundan-chariar1/diffusion_transformer_pytorch.git
%cd /content//diffusion_transformer_pytorch

!mkdir data
!mkdir latent_data
!mkdir weights

## Install Python dependencies with pip

In [None]:
!pip install -U pip
!pip install -r requirements.txt

## Download data

### Latent space data

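- Download the archive from this link and save it as `latent_data.zip`: https://drive.google.com/file/d/1u1O0DrS3aiKbLQQIGHg0xHjXEpAKBlgs/view?usp=sharing
- Put it in the `latent_data` folder created earlier; the next cell unzips it there and then deletes the archive.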

In [None]:
%cd data
!curl -L -o ./landscape-pictures.zip  https://www.kaggle.com/api/v1/datasets/download/arnaud58/landscape-pictures
!unzip ./landscape-pictures.zip
!rm ./landscape-pictures.zip
%cd ./../latent_data
!unzip ./latent_data.zip
!rm ./latent_data.zip

## Begin training
### Imports

In [None]:
# Return to the repo root: the data cell above leaves the working directory in
# latent_data/. The project modules imported below are presumably the repo's
# top-level .py files, found via the current directory — confirm.
%cd /content/diffusion_transformer_pytorch

# PyTorch stack.
import torch
import torchvision
from torchsummary import summary
from torch import nn

from diffusers.models import AutoencoderKL

# RNGs seeded in the configuration below.
import random
import numpy as np

# Project-local modules. These are star imports, so names used in later cells
# (DiT, ImageDataset, ImageDatasetTransformer, sanity_check, train_via_iter,
# calculate_fid, concat_images, ...) come from one of them — check the module
# sources to see which.
from dataloader import *
from embedding import *
from diffusion import *

from loss import *

from autoencoder import *
from transformer import *

from train_DiT import *
from train_vae import *

from utils import *
from testing import*

# Used to write the model config JSON.
import json

from IPython import get_ipython
from IPython.display import display, clear_output

# Use the GPU whenever PyTorch can see one.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f'Device chosen: {device}')

# Seed Python's `random`, NumPy, torch, and torch's current CUDA device, in that order.
seed = 42
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed):
    _seed_fn(seed)

# Shared data location and sizes used by the cells below.
DATA_DIR = "latent_data"
img_size = (256, 256)      # target size for the Resize transforms
latent_size = (32, 32)     # passed to DiT as input_size
latent_channels = 4        # passed to DiT as in_chans

In [None]:
# Quick visual check of the raw images in data/: resize + ToTensor only
# (no Normalize step, unlike the training transforms further down).
sanity_transforms = torchvision.transforms.Compose([
    torchvision.transforms.Resize(
        img_size,
        interpolation=torchvision.transforms.InterpolationMode.BILINEAR,
        antialias=True,
    ),
    torchvision.transforms.ToTensor(),
])
sanity_check(ImageDataset('data', transforms=sanity_transforms))

In [None]:
from pathlib import Path

# Architecture / diffusion hyperparameters; also serialised next to the weights.
model_config = dict(
    patch_size=2,
    num_layers=8,
    num_heads=6,
    T=1000,
)
# Run name encodes the hyperparameters, e.g. "2_8_6_1000".
config_string = "_".join(
    str(model_config[key]) for key in ("patch_size", "num_layers", "num_heads", "T")
)

# parents=True: works even if weights/ was not created yet;
# exist_ok=True: the cell can be re-run without error.
run_dir = Path("weights") / config_string
run_dir.mkdir(parents=True, exist_ok=True)

with open(run_dir / f"{config_string}.json", "w") as f:
    json.dump(model_config, f, indent=2)

print(f'All files will be saved under weights/{config_string}')

In [None]:
# Diffusion Transformer over the latent grid. patch_size / num_layers /
# num_heads / T are supplied by model_config from the previous cell.
model = DiT(
    input_size=latent_size,
    in_chans=latent_channels,
    embed_dim=384,
    hidden_dim=384,
    # NOTE(review): b_0 / b_T are presumably the start and end values of the
    # linear noise schedule — confirm in DiT.
    schedule_type='linear',
    b_0=1e-4,
    b_T=2e-2,
    **model_config,
)

In [None]:
# Resize -> tensor -> Normalize with mean = std = 0.5 per channel (so inputs
# in [0, 1] map to [-1, 1]). How ImageDatasetTransformer applies these to the
# data in DATA_DIR is defined in the project's dataloader, not here.
train_transforms = torchvision.transforms.Compose([
    torchvision.transforms.Resize(
        img_size,
        interpolation=torchvision.transforms.InterpolationMode.BILINEAR,
        antialias=True,
    ),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize([0.5] * 3, [0.5] * 3),
])

train_dataset = ImageDatasetTransformer(DATA_DIR, train_transforms)
# NOTE(review): test_dataset is the very same object as train_dataset (no held-out split).
test_dataset = train_dataset

In [None]:
# Iteration-based training run: AdamW, lr 1e-4, no weight decay.
# plot_freq presumably sets how often (in iterations) progress is plotted — confirm.
train_via_iter(
    model,
    train_dataset,
    batch_size=32,
    n_iters=1000,
    plot_freq=50,
    optimizer_name='AdamW',
    optimizer_config=dict(lr=1e-4, weight_decay=0),
    config_string=config_string,
)

In [None]:
# Persist the trained parameters (state_dict only) in the run directory,
# alongside the JSON config written earlier.
checkpoint_path = f'weights/{config_string}/{config_string}.pth'
torch.save(model.state_dict(), checkpoint_path)

In [None]:
# Post-training evaluation helpers from the project modules; both are keyed
# by the run's config string (the name defined in the config cell).
calculate_fid(config_string)
concat_images(config_string)

In [None]:
!zip weights/{config_string}.zip weights/{config_string}/*