In [None]:
from typing import List
import os
import random
import os
from typing import List, Any, Tuple, Optional

import PIL.Image
import torch
from torch.utils.data import Dataset
import torchvision.transforms as transforms
import PIL
from tqdm import tqdm
from datasets import load_dataset
import click
import yaml
import pickle
import torch
from torch.utils.data.dataloader import DataLoader
from torch.utils.data import random_split
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np

from training_utils.train_loop import train_loop, TrainingConfig
from generation import GenerationConfig
from models.gen.blocks import BaseDiffusionModel, UNet, UNetConfig
from models.gen.edm import EDM, EDMConfig
from models.gen.ddpm import DDPM, DDPMConfig
from utils.utils import EasyDict, instantiate_from_config
from data.data import SequencesDataset

def _save_sample_imgs(
    frames_real: torch.Tensor,
    frames_gen: List[torch.Tensor],
    path: str
):
    """Save a grid that compares real frames with generated frame sequences.

    Row 0 shows ``frames_real``; every following row shows one entry of
    ``frames_gen``. Column ``i`` shows frame ``i`` of each sequence. Each
    frame is converted for plotting with ``SequencesDataset.get_np_img``.

    Args:
        frames_real: Real frames; ``len(frames_real)`` sets the column count.
        frames_gen: Generated sequences, one per additional row. Each must
            hold at least ``len(frames_real)`` frames.
        path: Destination image file. A missing parent directory is created.
    """
    height_row = 5
    col_width = 5
    cols = len(frames_real)
    rows = 1 + len(frames_gen)

    # squeeze=False keeps `axes` two-dimensional even when there is a single
    # row or a single column, so `axes[row, col]` is always valid.
    fig, axes = plt.subplots(
        rows,
        cols,
        figsize=(col_width * cols, height_row * rows),
        squeeze=False,
    )
    for row, frames in enumerate([frames_real, *frames_gen]):
        for col in range(cols):
            axes[row, col].imshow(SequencesDataset.get_np_img(frames[col]))

    fig.subplots_adjust(wspace=0, hspace=0)

    # savefig does not create directories; make sure the target one exists.
    out_dir = os.path.dirname(path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    fig.savefig(path, bbox_inches='tight', pad_inches=0)
    # Close this specific figure so repeated calls do not accumulate figures.
    plt.close(fig)

def _generate_and_save_sample_imgs(
    model: BaseDiffusionModel,
    dataset: SequencesDataset,
    epoch: int,
    device: str,
    context_length: int,
    length_session = 20
):
    """Roll the model out on a random session and save a comparison image.

    A random session is picked from ``dataset``. Its first ``context_length``
    frames seed three independent autoregressive rollouts that call
    ``model.sample`` with 10, 5 and 2 denoise steps respectively. Each rollout
    feeds its own previously generated frames back in as context. The real
    frames and the three rollouts are written to ``val_images/{epoch}.png``.

    Args:
        model: Diffusion model exposing ``sample(steps, shape, frames, actions)``.
        dataset: Source of frames and of the ``first_0`` / ``second_0`` actions.
        epoch: Used only to name the output file.
        device: Device the context frames and actions are moved to.
        context_length: Number of previous frames/actions given to the model.
        length_session: Requested session length passed to ``get_session``.

    Raises:
        ValueError: If ``length_session`` does not exceed ``context_length``,
            since no frame could be generated.
    """
    if length_session <= context_length:
        raise ValueError(
            f"length_session ({length_session}) must be greater than "
            f"context_length ({context_length})"
        )

    index = random.randint(0, len(dataset) - length_session - 1)
    start, end = dataset.get_session(index, length_session)

    data = dataset.dataset[start:end]

    # Stack both action streams along dim 0.
    # NOTE(review): presumably one action per frame for each of two players,
    # giving shape (2, session_length) -- confirm against the dataset schema.
    actions = torch.stack(
        [torch.tensor(data['first_0']), torch.tensor(data['second_0'])],
        dim=0,
    ).to(device)

    # Real frames that seed every rollout.
    context_imgs = torch.stack(
        dataset.get_images(start, start + context_length), dim=0
    ).to(device)
    img_shape = context_imgs[-1].shape

    # Denoise steps: one rollout per entry, all starting from the same context.
    denoise_steps = (10, 5, 2)
    generated = {steps: context_imgs.clone() for steps in denoise_steps}

    # Frames are generated for visualization only, so autograd is not needed.
    with torch.no_grad():
        for i in range(length_session - context_length):
            prev_actions = actions[:, i: i + context_length].unsqueeze(0)
            # Same call order as before (10, then 5, then 2 steps per frame).
            for steps in denoise_steps:
                history = generated[steps]
                gen_img = model.sample(
                    steps,
                    img_shape,
                    history[-context_length:].unsqueeze(0),
                    prev_actions,
                )[0]
                generated[steps] = torch.concat(
                    [history, gen_img[None, :, :, :]], dim=0
                )

    real_images = torch.stack(dataset.get_images(start, end), dim=0)
    _save_sample_imgs(
        real_images,
        [generated[steps] for steps in denoise_steps],
        f"val_images/{epoch}.png",
    )

In [None]:
# Frame preprocessing: resize to 64x64, convert to a tensor, then normalize
# each of the three channels as (x - 0.5) / 0.5.
transform_to_tensor = transforms.Compose([
    transforms.Resize((64, 64)),
    transforms.ToTensor(),
    transforms.Normalize((.5,.5,.5), (.5,.5,.5))
])


# Run options for this notebook.
args = {
    "config": "../config/Diffusion.yaml",
    "model_type": "ddpm",
    "output_prefix": "output/boxing",
    "last_checkpoint": "",
    "gen_val_images": True,

    # "dataset": "data/sequences",
    # "output_loader": "output/sequences_loader.pkl",
}
options = EasyDict(**args)

# Read the YAML config with an explicit encoding so the result does not
# depend on the platform's default locale.
with open(options.config, 'r', encoding='utf-8') as f:
    config = EasyDict(**yaml.safe_load(f))

In [None]:

# Do not shuffle the dataset: the train/validation split below is taken by
# row index from the dataset in its stored order.
dataset = load_dataset("betteracs/boxing_atari_diffusion")

# Rows [0, 450_000) are used for training, rows [450_000, 500_000) for
# validation. Both wrappers share the sequence length, transform and seed.
train_seq_dataset, val_seq_dataset = (
    SequencesDataset(
        dataset=dataset['train'].select(range(first_row, last_row)),
        seq_length=config['generation']['context_length'],
        transform=transform_to_tensor,
        seed=42,
    )
    for first_row, last_row in ((0, 450_000), (450_000, 500_000))
)

# Only the training loader shuffles its samples; validation keeps its order.
train_dataloader, val_dataloader = (
    DataLoader(
        seq_dataset,
        batch_size=config['training']['batch_size'],
        shuffle=shuffle_samples,
        num_workers=config['training']['num_workers'],
    )
    for seq_dataset, shuffle_samples in (
        (train_seq_dataset, True),
        (val_seq_dataset, False),
    )
)

training_config = TrainingConfig(**config.training)
generation_config = GenerationConfig(**config.generation)

# generation_config.image_size = 160
# generation_config.context_length = 8
# training_config.batch_size = 2

# NOTE: `transform_to_tensor` is deliberately not redefined in this cell.
# Rebinding it to a pipeline without the Resize step would make a re-run of
# the dataset construction silently use un-resized frames.

# Device priority: Apple MPS if available, otherwise CUDA, otherwise CPU.
if torch.backends.mps.is_available():
    device = "mps"
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# The model-specific config is stored in `model_config` so that `config`
# keeps pointing at the parsed YAML and this cell can be re-run safely.
model: BaseDiffusionModel
if options.model_type == "edm":
    model_config = EDMConfig(**instantiate_from_config(config.edm))
    model = EDM.from_config(
        config=model_config,
        context_length=generation_config.context_length,
        device=device,
        model=UNet.from_config(
            config=model_config.unet,
            in_channels=generation_config.unet_input_channels,
            out_channels=generation_config.output_channels,
            actions_count=generation_config.actions_count,
            seq_length=generation_config.context_length
        )
    )
elif options.model_type == "ddpm":
    model_config = DDPMConfig(**instantiate_from_config(config.ddpm))
    model = DDPM.from_config(
        config=model_config,
        context_length=generation_config.context_length,
        device=device,
        model=UNet.from_config(
            config=model_config.unet,
            in_channels=generation_config.unet_input_channels,
            out_channels=generation_config.output_channels,
            actions_count=generation_config.actions_count,
            seq_length=generation_config.context_length,
            T=model_config.T,
            # player_autoencoder=model_config.player_autoencoder,
        )
    )
else:
    # Fail here rather than with a NameError on `model` further down.
    raise ValueError(
        f"Unsupported model_type {options.model_type!r}; expected 'edm' or 'ddpm'"
    )

def gen_val_images(epoch: int) -> None:
    """Generate and save validation sample images for the given epoch.

    Thin wrapper around ``_generate_and_save_sample_imgs`` that reads the
    notebook-level ``model``, ``val_seq_dataset``, ``device`` and
    ``generation_config`` at call time.
    """
    _generate_and_save_sample_imgs(
        model=model,
        dataset=val_seq_dataset,
        epoch=epoch,
        device=device,
        context_length=generation_config.context_length,
    )


# Run training and keep the per-epoch losses returned by `train_loop`.
print(f"Start training {options.model_type}")
training_losses, val_losses = train_loop(
    model=model,
    device=device,
    config=training_config,
    train_dataloader=train_dataloader,
    val_dataloader=val_dataloader,
    output_path_prefix=options.output_prefix,
    # Honor the `last_checkpoint` option; an empty string is passed as None.
    # NOTE(review): presumably train_loop treats None as "no checkpoint to
    # load" -- confirm against training_utils.train_loop.
    existing_model_path=options.last_checkpoint or None,
    # Sample images are generated only when enabled in the options.
    gen_imgs=gen_val_images if options.gen_val_images else None
)
print("-"*100)