In [1]:
# We clone the repository, but we only use its data; we don't use its code.
!git clone https://github.com/sergeytulyakov/mocogan.git
!ls mocogan/data/shapes


Cloning into 'mocogan'...
remote: Enumerating objects: 8366, done.[K
remote: Total 8366 (delta 0), reused 0 (delta 0), pack-reused 8366 (from 1)[K
Receiving objects: 100% (8366/8366), 83.29 MiB | 27.56 MiB/s, done.
Resolving deltas: 100% (83/83), done.
0  1


#Import

In [2]:
import os

import PIL

import functools
import IPython.display

import torch
from torch.utils.data import DataLoader
from torchvision import transforms

# models


In [None]:
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.utils.data
from torch.autograd import Variable

import numpy as np

# Namespace used to build tensors: the CUDA tensor types when a GPU is
# visible to torch, the CPU ones otherwise.
T = torch.cuda if torch.cuda.is_available() else torch


class Noise(nn.Module):
    """Layer that optionally adds zero-mean Gaussian noise to its input.

    Args:
        use_noise: when falsy, ``forward`` returns its input untouched.
        sigma: multiplier applied to the standard-normal noise.
    """

    def __init__(self, use_noise, sigma=0.2):
        super(Noise, self).__init__()
        self.use_noise = use_noise
        self.sigma = sigma

    def forward(self, x):
        """Return ``x + sigma * N(0, 1)`` when noise is enabled, else ``x`` itself."""
        if self.use_noise:
            # Draw the noise on the same device as x. Picking the device from
            # torch.cuda.is_available() breaks when the model stays on the CPU
            # on a machine that happens to have a GPU.
            return x + self.sigma * torch.randn_like(x)
        return x


class ImageDiscriminator(nn.Module):
    """2-D convolutional discriminator producing one logit per image.

    Four stride-2 4x4 convolutions (each preceded by a ``Noise`` layer) halve
    the spatial size and widen the features from ``ndf`` to ``ndf * 8``; a last
    un-padded 4x4 convolution maps to a single channel (a 64x64 input is
    reduced to 1x1).

    Args:
        n_channels: number of input image channels.
        ndf: number of feature maps after the first convolution.
        use_noise: forwarded to every ``Noise`` layer.
        noise_sigma: ``sigma`` of every ``Noise`` layer.
    """

    def __init__(self, n_channels, ndf=64, use_noise=False, noise_sigma=None):
        super().__init__()

        self.use_noise = use_noise

        # First stage has no batch norm.
        layers = [
            Noise(use_noise, sigma=noise_sigma),
            nn.Conv2d(n_channels, ndf, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
        ]

        # Three down-sampling stages, each doubling the channel count.
        for width in (ndf, ndf * 2, ndf * 4):
            layers += [
                Noise(use_noise, sigma=noise_sigma),
                nn.Conv2d(width, width * 2, 4, 2, 1, bias=False),
                nn.BatchNorm2d(width * 2),
                nn.LeakyReLU(0.2, inplace=True),
            ]

        # Final projection to a single logit map.
        layers.append(nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False))

        self.main = nn.Sequential(*layers)

    def forward(self, input):
        """Return ``(logits, None)``; singleton dimensions of the logits are squeezed."""
        return self.main(input).squeeze(), None


class PatchImageDiscriminator(nn.Module):
    """2-D convolutional discriminator that outputs a map of per-patch logits.

    Four stride-2 4x4 convolutions, each preceded by a ``Noise`` layer; the
    last one maps ``ndf * 4`` features to a single channel without a final
    un-padded convolution, so the output keeps a spatial extent.

    Args:
        n_channels: number of input image channels.
        ndf: number of feature maps after the first convolution.
        use_noise: forwarded to every ``Noise`` layer.
        noise_sigma: ``sigma`` of every ``Noise`` layer.
    """

    def __init__(self, n_channels, ndf=64, use_noise=False, noise_sigma=None):
        super().__init__()

        self.use_noise = use_noise

        layers = [
            Noise(use_noise, sigma=noise_sigma),
            nn.Conv2d(n_channels, ndf, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
        ]

        # Two batch-normalised down-sampling stages.
        for width in (ndf, ndf * 2):
            layers += [
                Noise(use_noise, sigma=noise_sigma),
                nn.Conv2d(width, width * 2, 4, 2, 1, bias=False),
                nn.BatchNorm2d(width * 2),
                nn.LeakyReLU(0.2, inplace=True),
            ]

        # Patch-level logits.
        layers += [
            Noise(use_noise, sigma=noise_sigma),
            nn.Conv2d(ndf * 4, 1, 4, 2, 1, bias=False),
        ]

        self.main = nn.Sequential(*layers)

    def forward(self, input):
        """Return ``(logits, None)``; singleton dimensions of the logits are squeezed."""
        return self.main(input).squeeze(), None


class PatchVideoDiscriminator(nn.Module):
    """3-D convolutional discriminator that outputs a volume of per-patch logits.

    Every convolution uses a 4x4x4 kernel with stride ``(1, 2, 2)`` and padding
    ``(0, 1, 1)``. The last convolution always has one output channel;
    ``n_output_neurons`` and ``bn_use_gamma`` are only stored on the instance.

    Args:
        n_channels: number of input channels.
        n_output_neurons: stored as an attribute (not used by the layers).
        bn_use_gamma: stored as an attribute (not used by the layers).
        use_noise: forwarded to every ``Noise`` layer.
        noise_sigma: ``sigma`` of every ``Noise`` layer.
        ndf: number of feature maps after the first convolution.
    """

    def __init__(self, n_channels, n_output_neurons=1, bn_use_gamma=True, use_noise=False, noise_sigma=None, ndf=64):
        super().__init__()

        self.n_channels = n_channels
        self.n_output_neurons = n_output_neurons
        self.use_noise = use_noise
        self.bn_use_gamma = bn_use_gamma

        def conv(c_in, c_out):
            # Shared convolution geometry for every stage of this network.
            return nn.Conv3d(c_in, c_out, 4, stride=(1, 2, 2), padding=(0, 1, 1), bias=False)

        layers = [
            Noise(use_noise, sigma=noise_sigma),
            conv(n_channels, ndf),
            nn.LeakyReLU(0.2, inplace=True),
        ]

        for width in (ndf, ndf * 2):
            layers += [
                Noise(use_noise, sigma=noise_sigma),
                conv(width, width * 2),
                nn.BatchNorm3d(width * 2),
                nn.LeakyReLU(0.2, inplace=True),
            ]

        layers.append(conv(ndf * 4, 1))

        self.main = nn.Sequential(*layers)

    def forward(self, input):
        """Return ``(logits, None)``; singleton dimensions of the logits are squeezed."""
        return self.main(input).squeeze(), None


class VideoDiscriminator(nn.Module):
    """3-D convolutional discriminator producing ``n_output_neurons`` values per clip.

    Four 4x4x4 convolutions with stride ``(1, 2, 2)`` and padding ``(0, 1, 1)``
    widen the features from ``ndf`` to ``ndf * 8``; a final un-padded, stride-1
    4x4x4 convolution maps them to ``n_output_neurons`` channels.

    Args:
        n_channels: number of input channels.
        n_output_neurons: number of output channels of the last convolution.
        bn_use_gamma: stored as an attribute (not used by the layers).
        use_noise: forwarded to every ``Noise`` layer.
        noise_sigma: ``sigma`` of every ``Noise`` layer.
        ndf: number of feature maps after the first convolution.
    """

    def __init__(self, n_channels, n_output_neurons=1, bn_use_gamma=True, use_noise=False, noise_sigma=None, ndf=64):
        super().__init__()

        self.n_channels = n_channels
        self.n_output_neurons = n_output_neurons
        self.use_noise = use_noise
        self.bn_use_gamma = bn_use_gamma

        def conv(c_in, c_out):
            # Down-sampling convolution shared by the first four stages.
            return nn.Conv3d(c_in, c_out, 4, stride=(1, 2, 2), padding=(0, 1, 1), bias=False)

        layers = [
            Noise(use_noise, sigma=noise_sigma),
            conv(n_channels, ndf),
            nn.LeakyReLU(0.2, inplace=True),
        ]

        for width in (ndf, ndf * 2, ndf * 4):
            layers += [
                Noise(use_noise, sigma=noise_sigma),
                conv(width, width * 2),
                nn.BatchNorm3d(width * 2),
                nn.LeakyReLU(0.2, inplace=True),
            ]

        layers.append(nn.Conv3d(ndf * 8, n_output_neurons, 4, 1, 0, bias=False))

        self.main = nn.Sequential(*layers)

    def forward(self, input):
        """Return ``(outputs, None)``; singleton dimensions of the outputs are squeezed."""
        return self.main(input).squeeze(), None


class CategoricalVideoDiscriminator(VideoDiscriminator):
    """``VideoDiscriminator`` whose output also carries ``dim_categorical`` category scores.

    The parent network is built with ``n_output_neurons + dim_categorical``
    outputs; ``forward`` splits them along dimension 1 into the leading
    real/fake part and the trailing categorical part.

    Args:
        n_channels: number of input channels.
        dim_categorical: number of trailing outputs treated as category scores.
        n_output_neurons: number of leading (non-categorical) outputs.
        use_noise: forwarded to the parent constructor.
        noise_sigma: forwarded to the parent constructor.
    """

    def __init__(self, n_channels, dim_categorical, n_output_neurons=1, use_noise=False, noise_sigma=None):
        total_outputs = n_output_neurons + dim_categorical
        super().__init__(
            n_channels=n_channels,
            n_output_neurons=total_outputs,
            use_noise=use_noise,
            noise_sigma=noise_sigma,
        )
        self.dim_categorical = dim_categorical

    def split(self, input):
        """Split ``input`` along dim 1 into (leading part, last ``dim_categorical`` columns)."""
        boundary = input.size(1) - self.dim_categorical
        return input[:, :boundary], input[:, boundary:]

    def forward(self, input):
        """Return ``(labels, categ)`` obtained by splitting the parent's output."""
        outputs, _ = super().forward(input)
        return self.split(outputs)


class VideoGenerator(nn.Module):
    """Generator that decodes per-frame latent codes into video frames.

    Each frame's latent vector is the concatenation ``[content | category |
    motion]``: the content and (one-hot) category parts are repeated for every
    frame of a video, the motion part comes from unrolling a ``GRUCell`` on
    Gaussian noise. A stack of transposed convolutions decodes each latent
    vector into one ``n_channels``-channel frame (64x64 for the layer geometry
    used here).

    All sampled tensors are created on the device that holds the module's
    parameters, so the generator works on the CPU even when CUDA is available.

    Args:
        n_channels: number of channels of the generated frames.
        dim_z_content: size of the per-video content code.
        dim_z_category: number of categories (``<= 0`` disables the category code).
        dim_z_motion: size of the per-frame motion code / GRU state.
        video_length: default number of frames per generated video.
        ngf: number of feature maps before the last transposed convolution.
    """

    def __init__(
        self, n_channels, dim_z_content, dim_z_category, dim_z_motion, video_length, ngf=64
    ):
        super(VideoGenerator, self).__init__()

        self.n_channels = n_channels
        self.dim_z_content = dim_z_content
        self.dim_z_category = dim_z_category
        self.dim_z_motion = dim_z_motion
        self.video_length = video_length

        dim_z = dim_z_motion + dim_z_category + dim_z_content
        self.recurrent = nn.GRUCell(dim_z_motion, dim_z_motion)

        self.main = nn.Sequential(
            nn.ConvTranspose2d(dim_z, ngf * 8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(ngf * 8),
            nn.ReLU(True),
            nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 4),
            nn.ReLU(True),
            nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 2),
            nn.ReLU(True),
            nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf),
            nn.ReLU(True),
            nn.ConvTranspose2d(ngf, self.n_channels, 4, 2, 1, bias=False),
            nn.Tanh()
        )

    def _device(self):
        """Device on which the module's parameters currently live."""
        return next(self.parameters()).device

    def sample_z_m(self, num_samples, video_len=None):
        """Sample motion codes; returns a ``(num_samples * video_len, dim_z_motion)`` tensor.

        Rows are ordered video by video (all frames of video 0, then video 1, ...).
        """
        video_len = video_len if video_len is not None else self.video_length

        h_t = self.get_gru_initial_state(num_samples)
        states = []
        for _ in range(video_len):
            e_t = self.get_iteration_noise(num_samples)
            h_t = self.recurrent(e_t, h_t)
            states.append(h_t.view(-1, 1, self.dim_z_motion))

        # (num_samples, video_len, dim) -> flattened so frames of a video are contiguous.
        return torch.cat(states, dim=1).view(-1, self.dim_z_motion)

    def sample_z_categ(self, num_samples, video_len=None):
        """Sample one category per video.

        Returns:
            ``(one_hot, classes)`` where ``one_hot`` is a float tensor of shape
            ``(num_samples * video_len, dim_z_category)`` and ``classes`` is an
            int64 NumPy array of length ``num_samples``. When categories are
            disabled, returns ``(None, zeros(num_samples))``.
        """
        video_len = video_len if video_len is not None else self.video_length

        if self.dim_z_category <= 0:
            return None, np.zeros(num_samples)

        # int64 regardless of platform so the labels can feed CrossEntropyLoss.
        classes_to_generate = np.random.randint(self.dim_z_category, size=num_samples).astype(np.int64)
        one_hot = np.zeros((num_samples, self.dim_z_category), dtype=np.float32)
        one_hot[np.arange(num_samples), classes_to_generate] = 1
        one_hot_video = np.repeat(one_hot, video_len, axis=0)

        one_hot_video = torch.from_numpy(one_hot_video).to(self._device())

        return one_hot_video, classes_to_generate

    def sample_z_content(self, num_samples, video_len=None):
        """Sample one content code per video, repeated for each of its frames."""
        video_len = video_len if video_len is not None else self.video_length

        content = np.random.normal(0, 1, (num_samples, self.dim_z_content)).astype(np.float32)
        content = np.repeat(content, video_len, axis=0)
        return torch.from_numpy(content).to(self._device())

    def sample_z_video(self, num_samples, video_len=None):
        """Return ``(z, category_labels)`` with one latent row per generated frame."""
        z_content = self.sample_z_content(num_samples, video_len)
        z_category, z_category_labels = self.sample_z_categ(num_samples, video_len)
        z_motion = self.sample_z_m(num_samples, video_len)

        if z_category is not None:
            z = torch.cat([z_content, z_category, z_motion], dim=1)
        else:
            z = torch.cat([z_content, z_motion], dim=1)

        return z, z_category_labels

    def sample_videos(self, num_samples, video_len=None):
        """Generate videos.

        Returns:
            ``(videos, labels)`` where ``videos`` has shape
            ``(num_samples, n_channels, video_len, H, W)`` and ``labels`` is a
            tensor with one sampled category per video.
        """
        video_len = video_len if video_len is not None else self.video_length

        z, z_category_labels = self.sample_z_video(num_samples, video_len)

        h = self.main(z.view(z.size(0), z.size(1), 1, 1))
        h = h.view(h.size(0) // video_len, video_len, self.n_channels, h.size(3), h.size(3))

        z_category_labels = torch.from_numpy(z_category_labels).to(h.device)

        # (video, frame, channel, H, W) -> (video, channel, frame, H, W)
        h = h.permute(0, 2, 1, 3, 4)
        return h, z_category_labels

    def sample_images(self, num_samples):
        """Generate ``num_samples`` single frames picked at random from freshly sampled videos.

        Returns ``(images, None)`` with ``images`` of shape
        ``(num_samples, n_channels, H, W)``.
        """
        z, z_category_labels = self.sample_z_video(num_samples * self.video_length * 2)

        j = np.sort(np.random.choice(z.size(0), num_samples, replace=False)).astype(np.int64)
        z = z[torch.from_numpy(j).to(z.device)]
        z = z.view(z.size(0), z.size(1), 1, 1)
        h = self.main(z)

        return h, None

    def get_gru_initial_state(self, num_samples):
        """Standard-normal initial GRU state of shape ``(num_samples, dim_z_motion)``."""
        return torch.randn(num_samples, self.dim_z_motion, device=self._device())

    def get_iteration_noise(self, num_samples):
        """Standard-normal GRU input of shape ``(num_samples, dim_z_motion)``."""
        return torch.randn(num_samples, self.dim_z_motion, device=self._device())

#data

In [None]:
import os
import tqdm
import pickle
import numpy as np
import torch.utils.data
from torchvision.datasets import ImageFolder
import PIL


class VideoFolderDataset(torch.utils.data.Dataset):
    """Index of "video strip" images stored in an ``ImageFolder`` layout.

    Every image is treated as a strip of square frames laid out along its
    longer side; its length in frames is ``longer_side // shorter_side``.
    Only strips with at least ``min_len`` frames are kept.

    Args:
        folder: root folder passed to ``torchvision.datasets.ImageFolder``.
        cache: optional path of a pickle file holding ``(images, lengths)``.
            When it exists and is non-empty it is loaded instead of scanning
            ``folder``; otherwise the scan result is written to it.
            Note that a loaded cache is used as-is: ``min_len`` is not
            re-applied to it.
        min_len: minimum number of frames a strip must have to be indexed.

    Attributes:
        images: list of ``(path, category)`` tuples.
        lengths: number of frames of each indexed strip.
        total_frames: sum of ``lengths``.
        cumsum: ``np.cumsum([0] + lengths)``, i.e. the global index of the
            first frame of every strip, followed by the total frame count.
    """

    def __init__(self, folder, cache, min_len=32):
        self.lengths = []
        self.images = []

        if cache is not None and os.path.exists(cache) and os.path.getsize(cache) != 0:
            # NOTE(security): pickle.load can execute arbitrary code; only point
            # `cache` at files produced by this class.
            with open(cache, 'rb') as f:
                self.images, self.lengths = pickle.load(f)
        else:
            # Imported locally under a private name: another cell rebinds the
            # global `tqdm` to the tqdm *class* (`from tqdm import tqdm`), which
            # would make `tqdm.tqdm(...)` fail here.
            from tqdm import tqdm as _progress

            # Only scan the folder when there is no usable cache.
            dataset = ImageFolder(folder)
            for idx, (im, categ) in enumerate(
                    _progress(dataset, desc="Counting total number of frames")):
                img_path, _ = dataset.imgs[idx]
                shorter, longer = min(im.width, im.height), max(im.width, im.height)
                length = longer // shorter
                if length >= min_len:
                    self.images.append((img_path, categ))
                    self.lengths.append(length)

            if cache is not None:
                with open(cache, 'wb') as f:
                    pickle.dump((self.images, self.lengths), f)

        self.total_frames = int(np.sum(self.lengths))
        self.cumsum = np.cumsum([0] + self.lengths)
        print("Total number of frames {}".format(np.sum(self.lengths)))

    def __getitem__(self, item):
        """Return ``(PIL image of the whole strip, category)`` for strip ``item``."""
        # Import the submodule explicitly; a bare `import PIL` does not
        # guarantee that `PIL.Image` is loaded.
        from PIL import Image

        path, label = self.images[item]
        im = Image.open(path)
        return im, label

    def __len__(self):
        """Number of indexed strips (videos)."""
        return len(self.images)


class ImageDataset(torch.utils.data.Dataset):
    """Flat, frame-level view over a dataset of video strips.

    Global index ``i`` addresses one frame: the strip is located through the
    wrapped dataset's ``cumsum`` array and the frame is cut out of the strip
    along its longer side.

    Args:
        dataset: object exposing ``cumsum`` (frame offsets of each strip plus
            the total) and ``dataset[video_id] -> (strip image, category)``.
        transform: optional callable applied to the extracted frame array.
    """

    def __init__(self, dataset, transform=None):
        self.dataset = dataset
        self.transforms = transform if transform is not None else lambda x: x

    def __getitem__(self, item):
        """Return ``{"images": frame, "categories": category}`` for global frame ``item``."""
        cumsum = self.dataset.cumsum
        if item < 0 or item >= cumsum[-1]:
            raise IndexError("frame index {} out of range".format(item))

        # side='right' makes an index equal to an offset belong to the strip that
        # starts there, so every frame (including the very last one) is
        # reachable exactly once.
        video_id = int(np.searchsorted(cumsum, item, side='right')) - 1
        frame_num = int(item - cumsum[video_id])

        video, target = self.dataset[video_id]
        video = np.array(video)

        horizontal = video.shape[1] > video.shape[0]

        if horizontal:
            # Frames are video.shape[0] pixels wide, laid out left to right.
            i_from, i_to = video.shape[0] * frame_num, video.shape[0] * (frame_num + 1)
            frame = video[:, i_from: i_to, ::]
        else:
            # Frames are video.shape[1] pixels tall, laid out top to bottom.
            i_from, i_to = video.shape[1] * frame_num, video.shape[1] * (frame_num + 1)
            frame = video[i_from: i_to, :, ::]

        if frame.shape[0] == 0:
            print("video {}. From {} to {}. num {}".format(video.shape, i_from, i_to, item))

        return {"images": self.transforms(frame), "categories": target}

    def __len__(self):
        """Total number of frames across all strips."""
        return int(self.dataset.cumsum[-1])


class VideoDataset(torch.utils.data.Dataset):
    """Clip-level view over a dataset of video strips.

    Each item is a sub-sequence of ``video_length`` frames cut from one strip.

    Args:
        dataset: object where ``dataset[i]`` returns ``(strip image, category)``.
        video_length: number of frames per returned clip.
        every_nth: temporal stride between selected frames when the strip is
            long enough (at least ``video_length * every_nth`` frames).
        transform: optional callable applied to the stacked frame array.
    """

    def __init__(self, dataset, video_length, every_nth=1, transform=None):
        self.dataset = dataset
        self.video_length = video_length
        self.every_nth = every_nth
        self.transforms = transform if transform is not None else lambda x: x

    def __getitem__(self, item):
        """Return ``{"images": clip, "categories": category}`` for strip ``item``.

        Raises:
            ValueError: if the strip holds fewer than ``video_length`` frames.
        """
        video, target = self.dataset[item]
        video = np.array(video)

        horizontal = video.shape[1] > video.shape[0]
        shorter, longer = min(video.shape[0], video.shape[1]), max(video.shape[0], video.shape[1])
        video_len = longer // shorter

        # videos can be of various length, we randomly sample sub-sequences
        if video_len >= self.video_length * self.every_nth:
            needed = self.every_nth * (self.video_length - 1)
            gap = video_len - needed
            start = 0 if gap == 0 else np.random.randint(0, gap, 1)[0]
            subsequence_idx = np.linspace(start, start + needed, self.video_length, endpoint=True, dtype=np.int32)
        elif video_len >= self.video_length:
            # Too short for the requested stride: fall back to the first consecutive frames.
            subsequence_idx = np.arange(0, self.video_length)
        else:
            raise ValueError("Length is too short id - {}, len - {}".format(item, video_len))

        # Drop trailing pixels that do not make up a whole frame so that
        # np.split always receives an evenly divisible strip.
        usable = video_len * shorter
        if horizontal:
            frames = np.split(video[:, :usable], video_len, axis=1)
        else:
            frames = np.split(video[:usable], video_len, axis=0)
        selected = np.array([frames[s_id] for s_id in subsequence_idx])

        return {"images": self.transforms(selected), "categories": target}

    def __len__(self):
        """Number of strips (videos) in the wrapped dataset."""
        return len(self.dataset)


class ImageSampler(torch.utils.data.Dataset):
    """Row-wise sampler over a dataset that exposes ``keys`` and ``get_data()``.

    ``get_data()`` is expected to return a mapping from each key to an array;
    item ``index`` gathers row ``index`` of every array.

    Args:
        dataset: object with a ``keys`` sequence and a ``get_data()`` method.
        transform: optional mapping ``key -> callable`` applied to the gathered rows.
    """

    def __init__(self, dataset, transform=None):
        self.dataset = dataset
        self.transforms = transform

    def __getitem__(self, index):
        """Return a dict with row ``index`` of every keyed array (after transforms)."""
        sample = {
            key: np.take(self.dataset.get_data()[key], index, axis=0)
            for key in self.dataset.keys
        }

        if self.transforms is None:
            return sample

        for key, fn in self.transforms.items():
            sample[key] = fn(sample[key])
        return sample

    def __len__(self):
        """Number of rows of the array stored under the first key."""
        first_key = self.dataset.keys[0]
        return self.dataset.get_data()[first_key].shape[0]


class VideoSampler(torch.utils.data.Dataset):
    """Clip sampler over a dataset that exposes ``keys`` and ``get_data()``.

    ``get_data()`` is expected to return a mapping of arrays that includes
    ``'video_ids'``; item ``i`` gathers the rows belonging to the ``i``-th
    unique video id and cuts a sub-sequence of ``video_length`` rows from them.

    Args:
        dataset: object with a ``keys`` sequence and a ``get_data()`` method.
        video_length: number of rows (frames) per returned clip.
        every_nth: stride between selected rows when the video is long enough.
        transform: optional mapping ``key -> callable`` applied to the result.
    """

    def __init__(self, dataset, video_length, every_nth=1, transform=None):
        self.dataset = dataset
        self.video_length = video_length
        self.unique_ids = np.unique(self.dataset.get_data()['video_ids'])
        self.every_nth = every_nth
        self.transforms = transform

    def __getitem__(self, item):
        """Return a dict with the selected rows of every keyed array.

        Videos shorter than ``video_length`` are reported with a message and
        returned untrimmed.
        """
        result = {}
        mask = self.dataset.get_data()['video_ids'] == self.unique_ids[item]
        # flatnonzero always yields a 1-D index array, so a video with a single
        # row keeps its leading axis (a squeezed 0-d index would drop it).
        ids = np.flatnonzero(mask)
        for k in self.dataset.keys:
            result[k] = np.take(self.dataset.get_data()[k], ids, axis=0)

        n_frames = len(ids)
        needed = self.every_nth * (self.video_length - 1)
        gap = n_frames - needed
        subsequence_idx = None

        # videos can be of various length, we randomly sample sub-sequences
        if n_frames > self.video_length and gap > 0:
            start = np.random.randint(0, gap, 1)[0]
            subsequence_idx = np.linspace(start, start + needed, self.video_length, endpoint=True, dtype=np.int32)
        elif n_frames >= self.video_length:
            # Exactly video_length rows, or too few for the requested stride:
            # take the first consecutive rows.
            subsequence_idx = np.arange(0, self.video_length)
        else:
            print("Length is too short id - {}, len - {}".format(self.unique_ids[item], n_frames))

        # `is not None`: the truth value of a multi-element array is ambiguous.
        if subsequence_idx is not None:
            for k in self.dataset.keys:
                result[k] = np.take(result[k], subsequence_idx, axis=0)

        if self.transforms is not None:
            for k, transform in self.transforms.items():
                result[k] = transform(result[k])

        return result

    def __len__(self):
        """Number of distinct video ids."""
        return len(self.unique_ids)

#Logger

In [None]:
import PIL
import tensorflow as tf
import numpy as np

try:
    from StringIO import StringIO  # Python 2.7
except ImportError:
    from io import BytesIO  # Python 3.x


class Logger(object):
    """Thin wrapper around a TensorFlow summary writer for scalars, images and videos.

    Args:
        log_dir: directory passed to ``tf.summary.create_file_writer``.
        suffix: optional ``filename_suffix`` for the event file.
    """

    def __init__(self, log_dir, suffix=None):
        self.writer = tf.summary.create_file_writer(log_dir, filename_suffix=suffix)

    @staticmethod
    def _encode_png(img):
        """Encode a uint8 image array as PNG bytes."""
        # Local imports keep the helper independent of how the surrounding
        # cell imported PIL / io.
        from io import BytesIO
        from PIL import Image

        buf = BytesIO()
        Image.fromarray(img).save(buf, format="png")
        return buf.getvalue()

    @staticmethod
    def _video_to_uint8(img):
        """Convert a tiled video strip to uint8.

        Floating-point input is scaled by 255 before clipping; integer input is
        taken to be in [0, 255] already and is only clipped. Multiplying uint8
        data by 255 would wrap around and corrupt the image.
        """
        img = np.asarray(img)
        if np.issubdtype(img.dtype, np.floating):
            img = img * 255
        return np.clip(img, 0, 255).astype(np.uint8)

    def scalar_summary(self, tag, value, step):
        """Write a single scalar and flush the writer."""
        with self.writer.as_default():
            tf.summary.scalar(tag, value, step=step)
            self.writer.flush()  # Optional: flush the writer to ensure the summary is written

    def image_summary(self, tag, images, step):
        """Write each image of ``images`` under ``'{tag}/{index}'``.

        Values are clipped to [0, 255] and cast to uint8 before PNG encoding.
        """
        with self.writer.as_default():
            for i, img in enumerate(images):
                img = np.clip(img, 0, 255).astype(np.uint8)  # Ensure values are in [0, 255]
                image_tensor = tf.image.decode_png(self._encode_png(img))
                tf.summary.image(f'{tag}/{i}', tf.expand_dims(image_tensor, 0), step=step)

        self.writer.flush()

    def video_summary(self, tag, videos, step):
        """Write each video as one image with its frames tiled left to right.

        ``videos`` is indexed ``(video, channel, frame, height, width)``, as
        implied by the transposes below; a one-pixel black column separates
        consecutive frames.
        """
        sh = list(videos.shape)
        sh[-1] = 1

        # One extra zero column per frame acts as a separator after tiling.
        separator = np.zeros(sh, dtype=videos.dtype)
        videos = np.concatenate([videos, separator], axis=-1)

        with self.writer.as_default():
            for i, vid in enumerate(videos):
                # (channel, frame, H, W) -> (frame, H, W, channel)
                v = vid.transpose(1, 2, 3, 0)
                v = [np.squeeze(f) for f in np.split(v, v.shape[0], axis=0)]
                # Tile along the width and drop the trailing separator column;
                # indexing two axes only also works for single-channel strips.
                img = np.concatenate(v, axis=1)[:, :-1]

                # Normalise the values to uint8.
                img = self._video_to_uint8(img)

                encoded_image_string = self._encode_png(img)
                tf.summary.image(f'{tag}/{i}', [tf.image.decode_png(encoded_image_string)], step=step)

        self.writer.flush()


#Trainer


In [None]:
import os
import time

from tqdm import tqdm

import numpy as np

import torch
from torch import nn

from torch.autograd import Variable
import torch.optim as optim

# Namespace used to build tensors: the CUDA tensor types when a GPU is
# visible to torch, the CPU ones otherwise.
T = torch.cuda if torch.cuda.is_available() else torch


def images_to_numpy(tensor):
    """Convert a batch of images in [-1, 1] to uint8 arrays in [0, 255].

    Args:
        tensor: tensor indexed ``(batch, channel, H, W)``.

    Returns:
        ``uint8`` array indexed ``(batch, H, W, channel)``. Values outside
        [-1, 1] are clamped first. The input tensor is left untouched.
    """
    generated = tensor.detach().cpu().numpy().transpose(0, 2, 3, 1)
    # np.clip returns a new array; clamping in place would write through to the
    # caller's tensor, because .numpy() shares memory with a CPU tensor.
    generated = np.clip(generated, -1, 1)
    generated = (generated + 1) / 2 * 255
    return generated.astype('uint8')


def videos_to_numpy(tensor):
    """Convert a batch of videos in [-1, 1] to uint8 arrays in [0, 255].

    Args:
        tensor: 5-D tensor; the axis order is preserved.

    Returns:
        ``uint8`` array with the same shape as ``tensor``. Values outside
        [-1, 1] are clamped first. The input tensor is left untouched.
    """
    generated = tensor.detach().cpu().numpy()
    # np.clip returns a new array; clamping in place would write through to the
    # caller's tensor, because .numpy() shares memory with a CPU tensor.
    generated = np.clip(generated, -1, 1)
    generated = (generated + 1) / 2 * 255
    return generated.astype('uint8')


def one_hot_to_class(tensor):
    """Return the sorted, distinct column indices of the non-zero entries of a 2-D array.

    The result is an ``int32`` array.
    """
    _, columns = np.nonzero(tensor)
    distinct = np.unique(columns)
    return distinct.astype(np.int32)


class Trainer(object):
    """Alternating GAN training loop: image discriminator, video discriminator, generator.

    Args:
        image_sampler: iterable of image batches (dicts of tensors with an
            ``'images'`` entry) exposing ``batch_size`` and ``len()``.
        video_sampler: iterable of video batches (dicts with ``'images'`` and,
            when categories are used, ``'categories'``) exposing ``batch_size``
            and ``len()``.
        log_interval: number of batches between two log/summary writes.
        train_batches: number of training iterations per epoch.
        log_folder: directory for summaries and generator checkpoints.
        use_cuda: move models and real batches to the GPU.
        use_infogan: add the category loss on generated videos to the generator loss.
        use_categories: add the category loss on real videos to the video
            discriminator loss.
    """

    def __init__(
        self,
        image_sampler,
        video_sampler,
        log_interval,
        train_batches,
        log_folder,
        use_cuda=False,
        use_infogan=True,
        use_categories=True
    ):
        self.use_categories = use_categories

        self.gan_criterion = nn.BCEWithLogitsLoss()
        self.category_criterion = nn.CrossEntropyLoss()

        self.image_sampler = image_sampler
        self.video_sampler = video_sampler

        self.video_batch_size = self.video_sampler.batch_size
        self.image_batch_size = self.image_sampler.batch_size

        self.log_interval = log_interval
        self.train_batches = train_batches

        self.log_folder = log_folder

        self.use_cuda = use_cuda
        self.use_infogan = use_infogan

        # Lazily created (batch_idx, batch) iterators over the two samplers.
        self.image_enumerator = None
        self.video_enumerator = None

    @staticmethod
    def ones_like(tensor, val=1.):
        """Float tensor shaped like ``tensor``, filled with ``val``, on ``tensor``'s device."""
        # Built from `tensor` itself so the target always lives on the same
        # device as the discriminator output it is compared with.
        return torch.full_like(tensor, val, dtype=torch.float)

    @staticmethod
    def zeros_like(tensor, val=0.):
        """Float tensor shaped like ``tensor``, filled with ``val``, on ``tensor``'s device."""
        return torch.full_like(tensor, val, dtype=torch.float)

    def compute_gan_loss(self, discriminator, sample_true, sample_fake, is_video):
        """Compute generator and discriminator losses from one real and one fake batch.

        NOTE(review): this method calls ``get_gt_for_discriminator`` and
        ``get_gt_for_generator``, which are not defined in the visible part of
        this class, and nothing visible calls this method; it looks like a
        leftover. Confirm before relying on it.
        """
        real_batch = sample_true()

        batch_size = real_batch['images'].size(0)
        fake_batch, generated_categories = sample_fake(batch_size)

        real_labels, real_categorical = discriminator(real_batch['images'])
        fake_labels, fake_categorical = discriminator(fake_batch)

        fake_gt, real_gt = self.get_gt_for_discriminator(batch_size, real=0.)

        l_discriminator = self.gan_criterion(real_labels, real_gt) + \
                          self.gan_criterion(fake_labels, fake_gt)

        fake_gt = self.get_gt_for_generator(batch_size)
        l_generator = self.gan_criterion(fake_labels, fake_gt)

        if is_video:

            # Ask the video discriminator to learn categories from training videos
            categories_gt = torch.squeeze(real_batch['categories'].long())
            l_discriminator += self.category_criterion(real_categorical, categories_gt)

            if self.use_infogan:
                # Ask the generator to generate categories recognizable by the discriminator
                l_generator += self.category_criterion(fake_categorical, generated_categories)

        return l_generator, l_discriminator

    def _next_batch(self, sampler, enumerator_attr):
        """Return the next batch of ``sampler``, restarting it after its last batch.

        ``enumerator_attr`` names the instance attribute that stores the
        ``enumerate`` iterator for this sampler between calls.
        """
        enumerator = getattr(self, enumerator_attr)
        if enumerator is None:
            enumerator = enumerate(sampler)

        batch_idx, batch = next(enumerator)
        if self.use_cuda:
            for k, v in batch.items():
                batch[k] = v.cuda()

        # Re-arm the iterator as soon as the last batch has been consumed so
        # the next call never hits StopIteration.
        if batch_idx == len(sampler) - 1:
            enumerator = enumerate(sampler)

        setattr(self, enumerator_attr, enumerator)
        return batch

    def sample_real_image_batch(self):
        """Next batch of real images (cycling through ``image_sampler`` indefinitely)."""
        return self._next_batch(self.image_sampler, 'image_enumerator')

    def sample_real_video_batch(self):
        """Next batch of real videos (cycling through ``video_sampler`` indefinitely)."""
        return self._next_batch(self.video_sampler, 'video_enumerator')

    def train_discriminator(self, discriminator, sample_true, sample_fake, opt, batch_size, use_categories):
        """Run one optimisation step of ``discriminator`` and return its loss.

        Args:
            discriminator: module returning ``(labels, categorical)``.
            sample_true: callable returning a real batch dict.
            sample_fake: callable ``batch_size -> (fake_batch, categories)``.
            opt: optimiser of the discriminator.
            batch_size: number of fake samples to draw.
            use_categories: add the category loss on the real batch.
        """
        opt.zero_grad()

        real_batch = sample_true()
        batch = real_batch['images']

        fake_batch, generated_categories = sample_fake(batch_size)

        real_labels, real_categorical = discriminator(batch)
        # detach: the generator must not receive gradients from this step.
        fake_labels, fake_categorical = discriminator(fake_batch.detach())

        ones = self.ones_like(real_labels)
        zeros = self.zeros_like(fake_labels)

        l_discriminator = self.gan_criterion(real_labels, ones) + \
                          self.gan_criterion(fake_labels, zeros)

        if use_categories:
            # Ask the video discriminator to learn categories from training videos
            categories_gt = torch.squeeze(real_batch['categories'].long())
            l_discriminator += self.category_criterion(real_categorical.squeeze(), categories_gt)

        l_discriminator.backward()
        opt.step()

        return l_discriminator

    def train_generator(
        self,
        image_discriminator,
        video_discriminator,
        sample_fake_images,
        sample_fake_videos,
        opt
    ):
        """Run one optimisation step of the generator against both discriminators.

        Returns the summed generator loss (image term + video term, plus the
        category term when ``use_infogan`` is set).
        """
        opt.zero_grad()

        # train on images
        fake_batch, generated_categories = sample_fake_images(self.image_batch_size)
        fake_labels, fake_categorical = image_discriminator(fake_batch)
        all_ones = self.ones_like(fake_labels)

        l_generator = self.gan_criterion(fake_labels, all_ones)

        # train on videos
        fake_batch, generated_categories = sample_fake_videos(self.video_batch_size)
        fake_labels, fake_categorical = video_discriminator(fake_batch)
        all_ones = self.ones_like(fake_labels)

        l_generator += self.gan_criterion(fake_labels, all_ones)

        if self.use_infogan:
            # Ask the generator to generate categories recognizable by the discriminator
            l_generator += self.category_criterion(fake_categorical.squeeze(), generated_categories)

        l_generator.backward()
        opt.step()

        return l_generator

    def train(self, generator, image_discriminator, video_discriminator, num_epochs):
        """Train for ``num_epochs`` epochs of ``train_batches`` iterations each.

        Every ``log_interval`` iterations the running losses and a batch of
        generated images/videos are written to the summary writer. The
        generator is saved to ``log_folder`` every 5 epochs and once more when
        training finishes.
        """
        # Local import so the loop does not depend on whether the global name
        # `tqdm` currently refers to the module or to the class.
        from tqdm import tqdm as progress_bar

        if self.use_cuda:
            generator.cuda()
            image_discriminator.cuda()
            video_discriminator.cuda()

        logger = Logger(self.log_folder)

        # create optimizers
        opt_generator = optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999), weight_decay=0.00001)
        opt_image_discriminator = optim.Adam(
            image_discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999), weight_decay=0.00001
        )
        opt_video_discriminator = optim.Adam(
            video_discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999), weight_decay=0.00001
        )

        def sample_fake_image_batch(batch_size):
            return generator.sample_images(batch_size)

        def sample_fake_video_batch(batch_size):
            return generator.sample_videos(batch_size)

        def init_logs():
            return {'l_gen': 0, 'l_image_dis': 0, 'l_video_dis': 0}

        # Iterations since the start of training; used as the summary step so
        # that curves from different epochs do not overwrite each other.
        global_step = 0

        for epoch in range(num_epochs):

            batch_num = 0
            logs = init_logs()
            epoch_logs = init_logs()  # Accumulate losses for the entire epoch
            start_time = time.time()

            with progress_bar(total=self.train_batches, desc=f'Epoch {epoch + 1}/{num_epochs}', unit='batch') as pbar:
                while batch_num < self.train_batches:
                    generator.train()
                    image_discriminator.train()
                    video_discriminator.train()

                    # train image discriminator
                    l_image_dis = self.train_discriminator(
                        image_discriminator, self.sample_real_image_batch,
                        sample_fake_image_batch, opt_image_discriminator,
                        self.image_batch_size, use_categories=False
                    )

                    # train video discriminator
                    l_video_dis = self.train_discriminator(
                        video_discriminator, self.sample_real_video_batch,
                        sample_fake_video_batch, opt_video_discriminator,
                        self.video_batch_size, use_categories=self.use_categories
                    )

                    # train generator
                    l_gen = self.train_generator(
                        image_discriminator, video_discriminator,
                        sample_fake_image_batch, sample_fake_video_batch,
                        opt_generator
                    )

                    step_losses = {
                        'l_gen': l_gen.item(),
                        'l_image_dis': l_image_dis.item(),
                        'l_video_dis': l_video_dis.item(),
                    }
                    for k, v in step_losses.items():
                        logs[k] += v
                        epoch_logs[k] += v

                    batch_num += 1
                    global_step += 1

                    if batch_num % self.log_interval == 0:
                        log_string = "Batch %d" % batch_num
                        for k, v in logs.items():
                            log_string += " [%s] %5.3f" % (k, v / self.log_interval)

                        log_string += ". Took %5.2f" % (time.time() - start_time)

                        print(log_string)

                        for tag, value in logs.items():
                            logger.scalar_summary(tag, value / self.log_interval, global_step)

                        logs = init_logs()
                        start_time = time.time()

                        generator.eval()

                        # Sampling for visualisation only: no autograd graph needed.
                        with torch.no_grad():
                            images, _ = sample_fake_image_batch(self.image_batch_size)
                            logger.image_summary("Images", images_to_numpy(images), global_step)

                            videos, _ = sample_fake_video_batch(self.video_batch_size)
                            logger.video_summary("Videos", videos_to_numpy(videos), global_step)

                    pbar.update(1)

            # Calculate average losses for the epoch
            n_batches = max(batch_num, 1)
            avg_l_gen = epoch_logs['l_gen'] / n_batches
            avg_l_image_dis = epoch_logs['l_image_dis'] / n_batches
            avg_l_video_dis = epoch_logs['l_video_dis'] / n_batches

            print(f"Epoch {epoch + 1}/{num_epochs} - Average losses: "
                  f"Generator: {avg_l_gen:.4f}, "
                  f"Image Discriminator: {avg_l_image_dis:.4f}, "
                  f"Video Discriminator: {avg_l_video_dis:.4f}")

            # Log average epoch losses
            logger.scalar_summary('avg_l_gen', avg_l_gen, epoch + 1)
            logger.scalar_summary('avg_l_image_dis', avg_l_image_dis, epoch + 1)
            logger.scalar_summary('avg_l_video_dis', avg_l_video_dis, epoch + 1)

            # Save generator weights every 5 epochs
            if (epoch + 1) % 5 == 0:
                torch.save(generator, os.path.join(self.log_folder, 'generator_%05d.pytorch' % (epoch + 1)))

        # Final checkpoint, named after the total number of iterations run
        # (equal to train_batches for a single-epoch run).
        if global_step > 0:
            torch.save(generator, os.path.join(self.log_folder, 'generator_%05d.pytorch' % global_step))


# Main

In [None]:
import os
import PIL

import functools

import torch
from torch.utils.data import DataLoader
from torchvision import transforms


def build_discriminator(type, **kwargs):
    """Instantiate a discriminator class selected by its name.

    Args:
        type: Name of the discriminator to build. Supported values are
            "PatchImageDiscriminator" and "CategoricalVideoDiscriminator".
        **kwargs: Keyword arguments forwarded unchanged to the selected
            class constructor.

    Returns:
        A new instance of the requested discriminator class.

    Raises:
        ValueError: If `type` is not one of the supported names. Previously
            an unknown name made the function return None, which only
            failed later when the result was used.
    """
    if type == "PatchImageDiscriminator":
        return PatchImageDiscriminator(**kwargs)
    if type == "CategoricalVideoDiscriminator":
        return CategoricalVideoDiscriminator(**kwargs)

    raise ValueError(
        "Unknown discriminator type %r; expected 'PatchImageDiscriminator' "
        "or 'CategoricalVideoDiscriminator'" % (type,)
    )


def video_transform(video, image_transform):
    """Transform every frame of a clip and stack the results into one tensor.

    Args:
        video: Iterable of frames.
        image_transform: Callable applied to each frame; it must return
            3-D tensors of identical shape so they can be stacked.

    Returns:
        The stacked frames with the first two axes swapped, i.e. the frame
        axis ends up at position 1 (e.g. (C, T, H, W) when each transformed
        frame is (C, H, W)).
    """
    frames = torch.stack([image_transform(frame) for frame in video])
    return frames.permute(1, 0, 2, 3)


# ---- Data ----
dataset = "mocogan/data/actions"  # folder handed to VideoFolderDataset (also hosts its 'local.db' cache)
n_channels = 3                    # channels kept by select_channels and used to size the models
img_size = 64                     # argument of transforms.Resize for every frame
video_length = 16                 # passed to VideoGenerator

# ---- Batch sizes (DataLoader batch_size) ----
image_batch = 10
video_batch = 3

# ---- Latent dimensions passed to VideoGenerator ----
dim_z_content = 30
dim_z_motion = 10
dim_z_category = 4  # also given to the video discriminator as dim_categorical

# ---- Trainer settings (forwarded to Trainer in the setup cell) ----
print_every = 1000
batches = 100000
log_folder = "./"
use_infogan = 0
use_categories = 3

# ---- Discriminator settings ----
use_noise = 0
noise_sigma = 0

# Class names resolved by build_discriminator.
image_discriminator = "PatchImageDiscriminator"
video_discriminator = "CategoricalVideoDiscriminator"


def select_channels(x):
    """Keep only the first `n_channels` entries along the leading axis of `x`.

    `n_channels` is read from the notebook's global configuration. In the
    image pipeline this runs right after `transforms.ToTensor()`, so it can
    drop surplus channels (e.g. an alpha channel) before normalisation.
    """
    leading = slice(None, n_channels)
    return x[leading, :]


In [None]:
# Per-frame preprocessing: array -> PIL image -> resize -> tensor -> keep the
# first n_channels channels -> normalise each channel with mean 0.5 / std 0.5.
image_transforms = transforms.Compose([
    PIL.Image.fromarray,
    transforms.Resize(img_size),
    transforms.ToTensor(),
    select_channels,
    # One mean/std entry per kept channel, so the pipeline follows n_channels
    # instead of being fixed to three channels.
    transforms.Normalize((0.5,) * n_channels, (0.5,) * n_channels),
])

# Clip-level transform: apply the frame pipeline to every frame of a clip.
video_transforms = functools.partial(video_transform, image_transform=image_transforms)

# Bind the loaded folder to its own name: `dataset` stays the path string, so
# this cell can be re-run without feeding a dataset object to os.path.join.
video_folder = VideoFolderDataset(dataset, cache=os.path.join(dataset, 'local.db'))
image_dataset = ImageDataset(video_folder, image_transforms)
image_loader = DataLoader(image_dataset, batch_size=image_batch, drop_last=True, num_workers=2, shuffle=True)

# Use the configured clip length (same value the generator is built with)
# rather than a second hard-coded 16.
# NOTE(review): assumes VideoDataset's second positional argument is the clip
# length - confirm against its definition earlier in the notebook.
video_dataset = VideoDataset(video_folder, video_length, 2, video_transforms)
video_loader = DataLoader(video_dataset, batch_size=video_batch, drop_last=True, num_workers=2, shuffle=True)

generator = VideoGenerator(n_channels, dim_z_content, dim_z_category, dim_z_motion, video_length)

# NOTE: the two assignments below replace the discriminator *name* strings
# from the configuration cell with model instances; re-run the configuration
# cell before re-running this one.
image_discriminator = build_discriminator(
    image_discriminator, n_channels=n_channels,
    use_noise=use_noise, noise_sigma=noise_sigma
)

video_discriminator = build_discriminator(
    video_discriminator, dim_categorical=dim_z_category,
    n_channels=n_channels, use_noise=use_noise,
    noise_sigma=noise_sigma
)

# Move all three networks to the GPU when one is available.
if torch.cuda.is_available():
    generator.cuda()
    image_discriminator.cuda()
    video_discriminator.cuda()

trainer = Trainer(
    image_loader,
    video_loader,
    print_every,
    batches,
    log_folder,
    use_cuda=torch.cuda.is_available(),
    use_infogan=use_infogan,
    use_categories=use_categories
)


Total number of frames 5290


In [None]:
# Launch training with the models and trainer built in the setup cell.
num_epochs = 10
trainer.train(
    generator,
    image_discriminator,
    video_discriminator,
    num_epochs,
)


Epoch 1/10:   1%|          | 569/100000 [1:00:00<175:06:35,  6.34s/batch]

## Generate videos

In [None]:
!apt-get install ffmpeg

Reading package lists... Done
Building dependency tree... Done
Reading state information... Done
ffmpeg is already the newest version (7:4.4.2-0ubuntu0.22.04.1).
0 upgraded, 0 newly installed, 0 to remove and 29 not upgraded.


In [None]:
import os
import torch

import subprocess as sp

# Checkpoint to load.
# NOTE(review): the training loop in this notebook saves checkpoints as
# 'generator_%05d.pytorch' inside its log folder - confirm that a file with
# this exact name exists before running the cell.
model = "generator_20000.pytorch"

# Sampling settings.
num_videos = 10
number_of_frames = 16

# Output settings.
output_folder = "./result"
output_format = "gif"

# Name (or path) of the ffmpeg executable used for encoding.
ffmpeg = "ffmpeg"


def save_video(ffmpeg, video, filename):
    """Encode raw RGB frames as a GIF by piping them to ffmpeg.

    Args:
        ffmpeg: Name or path of the ffmpeg executable.
        video: Array exposing ``tobytes()``; its bytes are sent to ffmpeg as
            ``rawvideo`` in ``rgb24`` pixel format. A 4-D array is assumed to
            be laid out as (frames, height, width, channels) and the frame
            size is taken from it; otherwise 64x64 is used, as before.
        filename: Output path handed to ffmpeg (overwritten because of ``-y``).

    Raises:
        RuntimeError: If ffmpeg exits with a non-zero status.
    """
    if getattr(video, "ndim", None) == 4:
        height, width = video.shape[1], video.shape[2]
    else:
        height, width = 64, 64

    command = [ffmpeg,
               '-y',
               '-f', 'rawvideo',
               '-vcodec', 'rawvideo',
               '-s', '{}x{}'.format(width, height),  # ffmpeg expects WIDTHxHEIGHT
               '-pix_fmt', 'rgb24',
               '-r', '8',
               '-i', '-',
               '-c:v', 'gif',
               filename]

    # run() writes the frames, closes stdin, drains stderr and waits for
    # ffmpeg to exit, so the output file is complete when this returns and no
    # child process or pipe is left behind.
    result = sp.run(command, input=video.tobytes(), stderr=sp.PIPE)
    if result.returncode != 0:
        raise RuntimeError(
            "ffmpeg exited with status {} while writing {!r}: {}".format(
                result.returncode, filename,
                result.stderr.decode(errors="replace").strip()))


# Fail early, and say which checkpoints do exist, when the configured one is
# missing.
if not os.path.isfile(model):
    checkpoint_dir = os.path.dirname(model) or "."
    if os.path.isdir(checkpoint_dir):
        available = sorted(
            name for name in os.listdir(checkpoint_dir)
            if name.startswith("generator_") and name.endswith(".pytorch")
        )
    else:
        available = []
    raise FileNotFoundError(
        "Generator checkpoint {!r} not found. Checkpoints in {!r}: {}".format(
            model, checkpoint_dir, ", ".join(available) or "none"))

# SECURITY NOTE: weights_only=False unpickles a complete Python object, which
# can execute arbitrary code - only load checkpoints you produced yourself.
generator = torch.load(model, map_location={'cuda:0': 'cpu'}, weights_only=False)
generator.eval()

# exist_ok avoids the check-then-create race of a separate exists() test.
os.makedirs(output_folder, exist_ok=True)

# Sampling only: no gradients are needed.
with torch.no_grad():
    for i in range(num_videos):
        v, _ = generator.sample_videos(1, number_of_frames)
        # Drop the singleton batch axis and move the first remaining axis last.
        # NOTE(review): presumably (C, T, H, W) -> (T, H, W, C), the frame-major
        # layout save_video pipes to ffmpeg - confirm against videos_to_numpy.
        video = videos_to_numpy(v).squeeze().transpose((1, 2, 3, 0))
        save_video(ffmpeg, video, os.path.join(output_folder, "{}.{}".format(i, output_format)))


FileNotFoundError: [Errno 2] No such file or directory: 'generator_20000.pytorch'