In [None]:
from typing import List
import os
import random
import os
from typing import List, Any, Tuple, Optional

import PIL.Image
import torch
from torch.utils.data import Dataset
import torchvision.transforms as transforms
import PIL
from tqdm import tqdm
from datasets import load_dataset
import click
import yaml
import pickle
import torch
from torch.utils.data.dataloader import DataLoader
from torch.utils.data import random_split
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np

from training_utils.train_loop import train_loop, TrainingConfig
from generation import GenerationConfig
from models.gen.blocks import BaseDiffusionModel, UNet, UNetConfig
from models.gen.edm import EDM, EDMConfig
from models.gen.ddpm import DDPM, DDPMConfig
from utils.utils import EasyDict, instantiate_from_config
from data.data import SequencesDataset

def _save_sample_imgs(
    frames_real: torch.Tensor,
    frames_gen: List[torch.Tensor],
    path: str
):
    """Save a grid comparing real frames (first row) with generated rollouts.

    Args:
        frames_real: indexable sequence of frames; each ``frames_real[i]`` must be
            a 3-D tensor whose first axis is moved last for display. Values are
            de-normalised with ``x * 127.5 + 127.5`` (assumes inputs in [-1, 1]).
        frames_gen: one sequence per additional row; each must contain at least
            ``len(frames_real)`` frames.
        path: output image path; its parent directory is created if missing.

    Raises:
        ValueError: if ``frames_real`` is empty (there would be no columns).
    """
    def get_np_img(tensor: torch.Tensor) -> np.ndarray:
        # Map back to uint8 [0, 255] and move channels last for imshow.
        return (tensor * 127.5 + 127.5).long().clip(0, 255).permute(1, 2, 0).detach().cpu().numpy().astype(np.uint8)

    height_row = 5
    col_width = 5
    cols = len(frames_real)
    rows = 1 + len(frames_gen)
    if cols == 0:
        raise ValueError("frames_real must contain at least one frame")

    # squeeze=False keeps `axes` 2-D even when rows == 1 or cols == 1,
    # so `axes[row, col]` indexing is always valid.
    fig, axes = plt.subplots(
        rows, cols, figsize=(col_width * cols, height_row * rows), squeeze=False
    )
    try:
        for row in range(rows):
            frames = frames_real if row == 0 else frames_gen[row - 1]
            for col in range(cols):
                ax = axes[row, col]
                ax.imshow(get_np_img(frames[col]))
                # Hide ticks so the zero-spacing grid is not overdrawn by labels.
                ax.axis("off")

        fig.subplots_adjust(wspace=0, hspace=0)

        directory = os.path.dirname(path)
        if directory:
            os.makedirs(directory, exist_ok=True)
        fig.savefig(path, bbox_inches='tight', pad_inches=0)
    finally:
        # Always release the figure; many epochs would otherwise leak memory.
        plt.close(fig)

def _generate_and_save_sample_imgs(
    model: BaseDiffusionModel,
    dataset: Dataset,
    epoch: int,
    device: str,
    context_length: int,
    length_session = 20,
    sampling_steps: Tuple[int, ...] = (10, 5, 2),
    output_dir: str = "val_images",
):
    """Roll the model out autoregressively and save a real-vs-generated grid.

    A random start item is drawn from ``dataset``. Its context frames seed one
    rollout per entry of ``sampling_steps``; each value is passed as the first
    argument of ``model.sample``. For each of the next ``length_session - 1``
    items, the item's real target frame is appended to the first row. Every
    rollout also samples one new frame, conditioned on its own last
    ``context_length`` frames and that item's actions. The grid is written to
    ``<output_dir>/<epoch>.png``.

    Args:
        model: diffusion model called as ``sample(steps, shape, context, actions)``.
        dataset: indexable dataset yielding ``(target, context_frames, actions)``.
        epoch: used as the output file name.
        device: device the frames and actions are moved to.
        context_length: number of most recent frames fed back as context.
        length_session: requested rollout length; clamped to ``len(dataset) - 1``.
        sampling_steps: one rollout row per value, drawn in this order.
        output_dir: directory for the PNG; created if missing.

    NOTE(review): the rollout treats ``dataset[index + j]`` as the step right
    after ``dataset[index + j - 1]``. That only holds if consecutive items are
    consecutive frames. SequencesDataset windows advance by about
    ``seq_length // 2`` rows, and a ``random_split`` Subset orders items
    randomly. Confirm this is the intended comparison.
    """
    if len(dataset) == 0:
        print("Skipping sample generation: dataset is empty")
        return

    # Clamp so that index + length_session - 1 stays inside the dataset.
    length_session = min(length_session, len(dataset) - 1)
    index = random.randint(0, len(dataset) - 1 - length_session)
    _, seed_context, _ = dataset[index]
    seed_context = seed_context.to(device)

    # Frames are kept as lists and stacked once at the end, instead of
    # re-concatenating an ever-growing tensor on every step.
    real_frames = list(seed_context)
    rollouts = [list(seed_context) for _ in sampling_steps]

    # NOTE(review): assumes model.sample needs no autograd (typical for
    # diffusion sampling). no_grad avoids building a graph across the rollout.
    with torch.no_grad():
        for j in range(1, length_session):
            target, _, actions = dataset[index + j]
            target = target.to(device)
            actions = actions.to(device).unsqueeze(0)
            real_frames.append(target)
            for steps, frames in zip(sampling_steps, rollouts):
                context = torch.stack(frames[-context_length:]).unsqueeze(0)
                frames.append(model.sample(steps, target.shape, context, actions)[0])

    os.makedirs(output_dir, exist_ok=True)
    _save_sample_imgs(
        torch.stack(real_frames),
        [torch.stack(frames) for frames in rollouts],
        os.path.join(output_dir, f"{epoch}.png"),
    )

In [2]:
class SequencesDataset(Dataset):
    """Sliding-window dataset of frames plus both players' actions.

    NOTE: this notebook-local class shadows the ``SequencesDataset`` imported
    from ``data.data`` in the import cell.

    Each sample is built from ``seq_length + 1`` consecutive rows of
    ``dataset`` that all share one ``game_id``. Window ends advance by
    ``max(seq_length // 2, 1)`` rows, so neighbouring samples overlap.
    Rows are read through the ``'image'``, ``'game_id'``, ``'first_0'`` and
    ``'second_0'`` keys. ``dataset[a:b]`` must return a mapping of column
    lists, and ``dataset[i]`` must return a mapping for one row.
    """

    def __init__(
        self,
        dataset,
        seq_length: int,
        transform: Optional[Any] = None,
    ) -> None:
        super().__init__()
        self.dataset = dataset
        # Each entry: ((start, end) row range, first_0 actions, second_0 actions).
        self.sequences: List[Tuple[Tuple[int, int], List[Any], List[Any]]] = []
        self.transform = transform

        for end in tqdm(self._window_ends(len(self.dataset), seq_length)):
            start = end - seq_length - 1
            rows = self.dataset[start:end]
            game_ids = rows['game_id']
            # Drop windows that straddle a game boundary.
            if all(game_id == game_ids[0] for game_id in game_ids):
                self.sequences.append(
                    ((start, end), list(rows['first_0']), list(rows['second_0']))
                )

    @staticmethod
    def _window_ends(num_rows: int, seq_length: int) -> range:
        """Exclusive end index of every ``seq_length + 1``-row window.

        The stop is ``num_rows + 1``, so the final row can end a window. The
        step is clamped to at least 1, because ``seq_length // 2`` is 0 for
        ``seq_length == 1`` and ``range`` rejects a zero step.
        """
        return range(seq_length + 1, num_rows + 1, max(seq_length // 2, 1))

    def __len__(self):
        return len(self.sequences)

    def _load_image(self, row: int):
        """Fetch one row's image, applying ``self.transform`` when set."""
        image = self.dataset[row]['image']
        # Without a transform the stored image is returned unchanged.
        # __getitem__'s torch.stack then requires it to already be a tensor.
        return image if self.transform is None else self.transform(image)

    def __getitem__(self, index: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Return ``(target_frame, context_frames, actions)`` for one window.

        - target_frame: the window's last frame.
        - context_frames: the preceding frames stacked along dim 0.
        - actions: ``first_0`` (row 0) and ``second_0`` (row 1) stacked along
          dim 0, each without the last row's action. With scalar actions the
          shape is ``(2, seq_length)``.
        """
        (start, end), actions1, actions2 = self.sequences[index]
        imgs = [self._load_image(row) for row in range(start, end)]

        actions = torch.stack(
            [torch.tensor(actions1)[:-1], torch.tensor(actions2)[:-1]], dim=0
        )
        return imgs[-1], torch.stack(imgs[:-1]), actions


def get_np_img(tensor: torch.Tensor) -> np.ndarray:
    """Turn a 3-D image tensor into a channels-last uint8 array for display.

    Pixels are mapped with ``x * 127.5 + 127.5`` (so -1 -> 0 and 1 -> 255),
    truncated to integers and clipped to [0, 255]. The first axis is moved
    last, so a (C, H, W) input becomes (H, W, C).
    """
    pixels = tensor * 127.5 + 127.5
    pixels = pixels.long().clip(0, 255)
    channels_last = pixels.permute(1, 2, 0)
    return channels_last.detach().cpu().numpy().astype(np.uint8)


In [3]:
# Per-frame preprocessing: resize to 64x64, convert to a float tensor, then
# normalise each of the three channels with mean 0.5 / std 0.5.
transform_to_tensor = transforms.Compose(
    [
        transforms.Resize((64, 64)),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ]
)
# Hugging Face dataset; rows are consumed below through the 'image',
# 'game_id', 'first_0' and 'second_0' columns.
dataset = load_dataset("betteracs/boxing_atari_diffusion")

In [4]:
# First 10,000 rows of the train split, cut into overlapping windows of
# 9 frames each (8 context frames + 1 target frame).
seq_dataset = SequencesDataset(
    dataset['train'].select(range(10000)),
    seq_length=8,
    transform=transform_to_tensor,
)

100%|██████████| 2498/2498 [00:05<00:00, 491.84it/s]


In [5]:
# Hold out a fixed number of windows for validation. The train size is derived
# from the dataset length, so the split does not break when the window count
# changes (different seq_length, row subset or game boundaries). The seeded
# generator makes the split reproducible across kernel restarts.
VAL_SIZE = 88
train_set, val_set = torch.utils.data.random_split(
    seq_dataset,
    [len(seq_dataset) - VAL_SIZE, VAL_SIZE],
    generator=torch.Generator().manual_seed(0),
)

In [6]:
# Each experiment trains one model from scratch. Earlier runs ("seq1".."seq4")
# used context lengths 1, 2, 4 and 8. Re-enabling them requires building
# `seq_dataset` with a matching `seq_length`.
experiments = [
    {
        "name": "Experiment6",
        "file": "seq5_with_ae",
        "seq_length": 8,
    },
]

# Number of context frames each sample provides (dim 0 of the stacked context
# returned by the dataset's __getitem__).
dataset_context_length = train_set[0][1].shape[0]

if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():  # macOS
    device = "mps"
else:
    device = "cpu"


def _build_model(model_type: str, raw_config, generation_config, device: str) -> BaseDiffusionModel:
    """Build the EDM or DDPM wrapper around a freshly configured UNet.

    Raises:
        ValueError: if ``model_type`` is neither "edm" nor "ddpm".
    """
    unet_kwargs = dict(
        in_channels=generation_config.unet_input_channels,
        out_channels=generation_config.output_channels,
        actions_count=generation_config.actions_count,
        seq_length=generation_config.context_length,
    )
    if model_type == "edm":
        edm_config = EDMConfig(**instantiate_from_config(raw_config.edm))
        return EDM.from_config(
            config=edm_config,
            context_length=generation_config.context_length,
            device=device,
            model=UNet.from_config(config=edm_config.unet, **unet_kwargs),
        )
    if model_type == "ddpm":
        ddpm_config = DDPMConfig(**instantiate_from_config(raw_config.ddpm))
        return DDPM.from_config(
            config=ddpm_config,
            context_length=generation_config.context_length,
            device=device,
            model=UNet.from_config(config=ddpm_config.unet, T=ddpm_config.T, **unet_kwargs),
        )
    raise ValueError(f"Unknown model_type {model_type!r}; expected 'edm' or 'ddpm'")


for experiment in experiments:
    print("-" * 100)
    print(experiment['name'])

    options = EasyDict(
        config="snake-diffusion/config/Diffusion.yaml",
        model_type="ddpm",
        output_prefix=f"output/{experiment['file']}",
        last_checkpoint="",  # checkpoint to resume from; "" trains from scratch
        gen_val_images=True,
    )
    with open(options.config, 'r') as f:
        config = EasyDict(**yaml.safe_load(f))

    # The model cannot condition on more frames than each sample provides.
    # NOTE(review): a shorter context is only safe if train_loop keeps the last
    # `context_length` frames of each batch. Confirm before running seq1..seq3.
    if experiment['seq_length'] > dataset_context_length:
        raise ValueError(
            f"{experiment['name']}: seq_length={experiment['seq_length']} exceeds the "
            f"{dataset_context_length} context frames per sample; rebuild seq_dataset"
        )
    config['generation']['context_length'] = experiment['seq_length']
    # NOTE(review): 18 presumably matches the full Atari action set — confirm.
    config['generation']['actions_count'] = 18

    train_dataloader = DataLoader(
        train_set,
        batch_size=config['training']['batch_size'],
        shuffle=True,
        num_workers=config['training']['num_workers'],
    )
    val_dataloader = DataLoader(
        val_set,
        batch_size=config['training']['batch_size'],
        num_workers=config['training']['num_workers'],
    )
    training_config = TrainingConfig(**config.training)
    generation_config = GenerationConfig(**config.generation)

    model = _build_model(options.model_type, config, generation_config, device)

    if options.gen_val_images:
        # _generate_and_save_sample_imgs writes val_images/<epoch>.png by default.
        os.makedirs("val_images", exist_ok=True)

    def gen_val_images(epoch: int):
        _generate_and_save_sample_imgs(model, val_set, epoch, device, generation_config.context_length)

    print(f"Start training {options.model_type}")
    train_loop(
        model=model,
        device=device,
        config=training_config,
        train_dataloader=train_dataloader,
        val_dataloader=val_dataloader,
        output_path_prefix=options.output_prefix,
        existing_model_path=options.last_checkpoint or None,
        gen_imgs=gen_val_images if options.gen_val_images else None,
    )

print("-" * 100)


----------------------------------------------------------------------------------------------------
Experiment6
Start training ddpm


loss for epoch 1: 0.4391:   7%|▋         | 41/600 [01:37<22:06,  2.37s/it]


KeyboardInterrupt: 