# **Diffusion model on CIFAR-100**

By Katherine Crowson (https://github.com/crowsonkb, https://twitter.com/RiversHaveWings).

The model is a denoising diffusion probabilistic model (https://arxiv.org/abs/2006.11239), which is trained to reverse a gradual noising process, allowing the model to generate samples from the learned data distribution starting from random noise. DDIM-style deterministic sampling (https://arxiv.org/abs/2010.02502) is also supported. This model is also trained on continuous timesteps parameterized by the log SNR on each timestep (see Variational Diffusion Models, https://arxiv.org/abs/2107.00630), allowing different noise schedules than the one used during training to be easily used during sampling. It uses the 'v' objective from Progressive Distillation for Fast Sampling of Diffusion Models (https://openreview.net/forum?id=TIdIXIpzhoI) for better conditioned denoised images at high noise levels, but reweights the loss function so that it has the same relative weighting as the 'eps' objective.

In [None]:
# @title Licensed under the MIT License

# Copyright (c) 2021 Katherine Crowson

# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:

# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.

# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.

In [None]:
# Check the GPU type
!nvidia-smi

## **Library imports**

In [None]:
!pip install PyDrive

In [None]:
import math
import numpy as np
from IPython import display
from matplotlib import pyplot as plt
from contextlib import contextmanager
from copy import deepcopy
import torch
from torchvision.datasets import CIFAR10
from torch.utils.data import DataLoader, SubsetRandomSampler
import torchvision.transforms as transforms

import numpy as np
import matplotlib.pyplot as plt


from torch import optim, nn
from torch.nn import functional as F
from torch.utils import data
from torchvision import datasets, transforms, utils
from torchvision.transforms import functional as TF
from tqdm.notebook import tqdm, trange

In [None]:
from pydrive.auth import GoogleAuth
from pydrive.drive import GoogleDrive
from google.colab import auth
from oauth2client.client import GoogleCredentials

## **Utility functions**

In [None]:
@contextmanager
def train_mode(model, mode=True):
    """Temporarily switch `model` (and all its submodules) to training mode.

    Yields the result of ``model.train(mode)`` (the model itself). On exit,
    even if an exception is raised, each submodule's ``training`` flag is set
    back to the value it had on entry.
    """
    saved_flags = [submodule.training for submodule in model.modules()]
    try:
        yield model.train(mode)
    finally:
        for submodule, was_training in zip(model.modules(), saved_flags):
            submodule.training = was_training


def eval_mode(model):
    """Context manager (also usable as a decorator) that puts `model` into
    evaluation mode and restores every submodule's previous mode on exit."""
    return train_mode(model, mode=False)


@torch.no_grad()
def ema_update(model, averaged_model, decay):
    """Fold the current weights of `model` into the moving average `averaged_model`.

    Meant to be called once after every optimizer step. Each parameter is
    updated in place as ``avg = decay * avg + (1 - decay) * param``; buffers
    are copied over unchanged. Both models must have identical parameter and
    buffer names.
    """
    live_params = dict(model.named_parameters())
    ema_params = dict(averaged_model.named_parameters())
    assert live_params.keys() == ema_params.keys()
    for key in live_params:
        ema_params[key].mul_(decay).add_(live_params[key], alpha=1 - decay)

    live_buffers = dict(model.named_buffers())
    ema_buffers = dict(averaged_model.named_buffers())
    assert live_buffers.keys() == ema_buffers.keys()
    for key in live_buffers:
        ema_buffers[key].copy_(live_buffers[key])

## **Model definition**
A residual U-Net.

In [None]:
class ResidualBlock(nn.Module):
    """Residual wrapper computing ``main(x) + skip(x)``.

    `main` is a list of layers wrapped in ``nn.Sequential``; `skip` defaults to
    the identity when not given.
    """

    def __init__(self, main, skip=None):
        super().__init__()
        self.main = nn.Sequential(*main)
        # Compare with None explicitly: a module defining __len__ (e.g. an
        # empty nn.Sequential) would otherwise be treated as "no skip".
        self.skip = skip if skip is not None else nn.Identity()

    def forward(self, input):
        return self.main(input) + self.skip(input)


class ResConvBlock(ResidualBlock):
    """Two 3x3 convolutions with dropout and ReLU, plus a residual connection.

    Args:
        c_in: number of input channels.
        c_mid: number of channels after the first convolution.
        c_out: number of output channels. When it differs from `c_in`, the
            residual path is a bias-free 1x1 convolution instead of the identity.
        dropout_last: whether to apply ``Dropout2d(0.1)`` after the second
            convolution. Pass ``False`` for the network's output block.
    """

    def __init__(self, c_in, c_mid, c_out, dropout_last=True):
        skip = None if c_in == c_out else nn.Conv2d(c_in, c_out, 1, bias=False)
        super().__init__([
            nn.Conv2d(c_in, c_mid, 3, padding=1),
            nn.Dropout2d(0.1, inplace=True),
            nn.ReLU(inplace=True),
            nn.Conv2d(c_mid, c_out, 3, padding=1),
            # `dropout_last` used to be silently ignored. Dropout2d/Identity hold
            # no parameters, so the convolutions keep their state_dict keys
            # (main.0.*, main.3.*) and existing checkpoints still load; in
            # eval mode this layer is a no-op either way.
            nn.Dropout2d(0.1, inplace=True) if dropout_last else nn.Identity(),
            # NOTE(review): the output block still ends in a ReLU. Removing it
            # would change the function computed by weights trained with this
            # block, so it is kept.
            nn.ReLU(inplace=True),
        ], skip)

In [None]:
class SkipBlock(nn.Module):
    """Concatenate a sub-network's output with a skip path along channels.

    ``forward(x)`` returns ``cat([main(x), skip(x)], dim=1)``; without an
    explicit `skip` module the skip path is the identity.
    """

    def __init__(self, main, skip=None):
        super().__init__()
        self.main = nn.Sequential(*main)
        if skip:
            self.skip = skip
        else:
            self.skip = nn.Identity()

    def forward(self, input):
        branch = self.main(input)
        shortcut = self.skip(input)
        return torch.cat((branch, shortcut), dim=1)


class FourierFeatures(nn.Module):
    """Random Fourier feature map ``x -> [cos(2*pi*x W^T), sin(2*pi*x W^T)]``.

    ``W`` has shape ``(out_features // 2, in_features)``, is drawn from a normal
    distribution scaled by `std`, and is an ``nn.Parameter`` (so it is trained).
    `out_features` must be even.
    """

    def __init__(self, in_features, out_features, std=1.):
        super().__init__()
        assert out_features % 2 == 0
        half = out_features // 2
        self.weight = nn.Parameter(torch.randn([half, in_features]) * std)

    def forward(self, input):
        phase = 2 * math.pi * input @ self.weight.T
        return torch.cat([torch.cos(phase), torch.sin(phase)], dim=-1)


def expand_to_planes(input, shape):
    """Tile each value of `input` over a ``shape[2] x shape[3]`` spatial plane.

    Intended for a ``(batch, channels)`` tensor, giving
    ``(batch, channels, shape[2], shape[3])``; the result is a real copy
    (``repeat``), not a broadcast view.
    """
    height, width = shape[2], shape[3]
    planes = input.unsqueeze(-1).unsqueeze(-1)
    return planes.repeat(1, 1, height, width)

In [None]:
class Diffusion(nn.Module):
    """Residual U-Net that predicts the 'v' target for a noisy image at a given log SNR.

    Takes 3-channel images, average-pools by 2 three times (32x32 -> 4x4 for
    CIFAR-sized inputs) and upsamples back, concatenating skip connections
    along the channel axis.

    The model is unconditional: `cond` is accepted for API compatibility but
    ignored, and the 4 "class" input channels are always zero.
    """

    def __init__(self):
        super().__init__()
        c = 64  # The base channel count
        # The inputs to timestep_embed will approximately fall into the range
        # -10 to 10, so use std 0.2 for the Fourier Features.
        self.timestep_embed = FourierFeatures(1, 16, std=0.2)
        self.net = nn.Sequential(   # 32x32
            ResConvBlock(3 + 16 + 4, c, c),
            ResConvBlock(c, c, c),
            SkipBlock([
                nn.AvgPool2d(2),  # 32x32 -> 16x16
                ResConvBlock(c, c * 2, c * 2),
                ResConvBlock(c * 2, c * 2, c * 2),
                SkipBlock([
                    nn.AvgPool2d(2),  # 16x16 -> 8x8
                    ResConvBlock(c * 2, c * 4, c * 4),
                    ResConvBlock(c * 4, c * 4, c * 4),
                    SkipBlock([
                        nn.AvgPool2d(2),  # 8x8 -> 4x4
                        ResConvBlock(c * 4, c * 8, c * 8),
                        ResConvBlock(c * 8, c * 8, c * 8),
                        ResConvBlock(c * 8, c * 8, c * 8),
                        ResConvBlock(c * 8, c * 8, c * 4),
                        nn.Upsample(scale_factor=2),
                    ]),  # 4x4 -> 8x8
                    ResConvBlock(c * 8, c * 4, c * 4),
                    ResConvBlock(c * 4, c * 4, c * 2),
                    nn.Upsample(scale_factor=2),
                ]),  # 8x8 -> 16x16
                ResConvBlock(c * 4, c * 2, c * 2),
                ResConvBlock(c * 2, c * 2, c),
                nn.Upsample(scale_factor=2),
            ]),  # 16x16 -> 32x32
            ResConvBlock(c * 2, c, c),
            ResConvBlock(c, c, 3, dropout_last=False),
        )

    def _embed_input(self, input, log_snrs):
        """Build the network input: image, 4 zero "class" channels, timestep embedding."""
        timestep_embed = expand_to_planes(self.timestep_embed(log_snrs[:, None]), input.shape)
        b, _, h, w = input.shape
        # Allocate on the input's device instead of relying on a notebook-global
        # `device` (which is only defined further down the notebook).
        class_embed = torch.zeros(b, 4, h, w, device=input.device)
        return torch.cat([input, class_embed, timestep_embed], dim=1)

    def forward(self, input, log_snrs, cond):
        """Predict v for `input` (batch first) at per-sample log SNRs `log_snrs`."""
        return self.net(self._embed_input(input, log_snrs))

    def _forward_collect(self, layers, x, taps):
        """Apply `layers` in order (as nn.Sequential would), recursing into
        SkipBlocks and appending the input of every nn.Upsample to `taps`."""
        for layer in layers:
            if isinstance(layer, SkipBlock):
                inner = self._forward_collect(layer.main, x, taps)
                x = torch.cat([inner, layer.skip(x)], dim=1)
            elif isinstance(layer, nn.Upsample):
                taps.append(x)
                x = layer(x)
            else:
                x = layer(x)
        return x

    @torch.no_grad()
    def get_features(self, input, log_snrs, cond):
        """Return globally average-pooled U-Net activations, shape (batch, 512).

        Four activations are tapped, in this order:
          1. the input of the innermost nn.Upsample (c*4 = 256 channels),
          2. the input of the middle nn.Upsample (c*2 = 128 channels),
          3. the input of the outermost nn.Upsample (c = 64 channels),
          4. the output of the third top-level ResConvBlock (c = 64 channels).
        Each is averaged over its spatial dimensions, and the results are
        concatenated along the channel axis. The result is returned on the CPU
        without gradient tracking, which is what the existing callers expect.
        """
        x = self._embed_input(input, log_snrs)
        taps = []
        # Run everything up to and including the third top-level ResConvBlock
        # (self.net[3]); the output block self.net[4] contributes no features.
        x = self._forward_collect(list(self.net)[:4], x, taps)
        taps.append(x)
        # Global average pooling: on 32x32 inputs this equals fixed kernels of
        # 4/8/16/32, but it also works at other resolutions. flatten(1) keeps
        # the batch axis even for a batch of one image.
        pooled = [F.adaptive_avg_pool2d(tap, 1).flatten(1) for tap in taps]
        return torch.cat(pooled, dim=1).cpu()

## **Sample**

In [None]:
def get_alphas_sigmas(log_snrs):
    """Convert log signal-to-noise ratios into (alpha, sigma) scale factors.

    alpha scales the clean image and sigma scales the noise. Since
    alpha**2 = sigmoid(log_snr) and sigma**2 = sigmoid(-log_snr), the pair
    satisfies alpha**2 + sigma**2 == 1.
    """
    alphas = torch.sigmoid(log_snrs).sqrt()
    sigmas = torch.sigmoid(-log_snrs).sqrt()
    return alphas, sigmas

def get_ddpm_schedule(t):
    """Log SNR of the DDPM noise schedule at continuous timestep(s) `t` in [0, 1].

    Computes ``-log(expm1(1e-4 + 10 * t**2))``; over [0, 1] the log SNR
    decreases as t increases.
    """
    exponent = 1e-4 + 10 * t ** 2
    return torch.log(torch.special.expm1(exponent)).neg()


@torch.no_grad()
def sample(model, x, steps, eta, classes):
    """Draw samples from `model`, starting from the noise tensor `x`.

    Runs `steps` denoising steps on the DDPM log-SNR schedule, from t=1 down to
    t=1/steps, and returns the model's prediction of the clean images.

    Args:
        model: callable ``model(x, log_snrs, classes)`` returning the predicted v.
        x: starting noise, batch first.
        steps: number of sampling steps; must be at least 1.
        eta: fresh-noise scale per step (0 = deterministic DDIM, 1 = DDPM).
        classes: forwarded to `model` unchanged.

    Returns:
        The denoised images predicted at the final step.

    Raises:
        ValueError: if ``steps < 1`` (there would be no prediction to return).
    """
    if steps < 1:
        raise ValueError(f'steps must be at least 1, got {steps}')
    ts = x.new_ones([x.shape[0]])

    # Create the noise schedule: t = 1, 1 - 1/steps, ..., 1/steps (t=0 is dropped)
    t = torch.linspace(1, 0, steps + 1)[:-1]
    log_snrs = get_ddpm_schedule(t)
    alphas, sigmas = get_alphas_sigmas(log_snrs)

    # The sampling loop
    for i in trange(steps):

        # Get the model output (v, the predicted velocity). torch.autocast('cuda')
        # is the non-deprecated form of torch.cuda.amp.autocast().
        with torch.autocast('cuda'):
            v = model(x, ts * log_snrs[i], classes).float()

        # Predict the noise and the denoised image
        pred = x * alphas[i] - v * sigmas[i]
        eps = x * sigmas[i] + v * alphas[i]

        # If we are not on the last timestep, compute the noisy image for the
        # next timestep.
        if i < steps - 1:
            # If eta > 0, adjust the scaling factor for the predicted noise
            # downward according to the amount of additional noise to add
            ddim_sigma = eta * (sigmas[i + 1]**2 / sigmas[i]**2).sqrt() * \
                (1 - alphas[i]**2 / alphas[i + 1]**2).sqrt()
            adjusted_sigma = (sigmas[i + 1]**2 - ddim_sigma**2).sqrt()
            # Recombine the predicted noise and predicted denoised image in the
            # correct proportions for the next step
            x = pred * alphas[i + 1] + eps * adjusted_sigma
            # Add the correct amount of fresh noise
            if eta:
                x += torch.randn_like(x) * ddim_sigma
    # If we are on the last timestep, output the denoised image
    return pred

## **Noise schedule**

In [None]:
import torch
import matplotlib.pyplot as plt

# Visualise the DDPM noise schedule used for training and sampling, over 1000
# evenly spaced timesteps t in [0, 1]: left, the signal scale alpha and noise
# scale sigma; right, the log SNR.
%config InlineBackend.figure_format = 'retina'
plt.rcParams['figure.dpi'] = 100

t_vis = torch.linspace(0, 1, 1000)
log_snrs_vis = get_ddpm_schedule(t_vis)
alphas_vis, sigmas_vis = get_alphas_sigmas(log_snrs_vis)

fig, axes = plt.subplots(1, 2, figsize=(10, 5))

axes[0].plot(t_vis, alphas_vis, label='alpha (signal level)')
axes[0].plot(t_vis, sigmas_vis, label='sigma (noise level)')
axes[0].legend()
axes[0].set_xlabel('timestep')
axes[0].grid()

axes[1].plot(t_vis, log_snrs_vis, label='log SNR')
axes[1].legend()
axes[1].set_xlabel('timestep')
axes[1].grid()

plt.tight_layout()
plt.show()


## **Dataset**

In [None]:
batch_size = 32

# tf = transforms.Compose([
#     transforms.ToTensor(),
#     transforms.Normalize([0.5], [0.5]),
# ])

# Per-channel normalisation with CIFAR-100 mean/std (instead of the [-1, 1]
# scaling of the commented-out transform above).
# NOTE(review): demo() maps samples back to images with (x + 1) / 2, which
# assumes [-1, 1] data; with this normalisation the shown colours may be off.
tf = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.5071, 0.4867, 0.4408], [0.2675, 0.2565, 0.2761])
])


train_set = datasets.CIFAR100('data', train=True, download=True, transform=tf)
val_set = datasets.CIFAR100('data', train=False, download=True, transform=tf)

# Only the training loader is shuffled; persistent_workers keeps the worker
# processes alive between epochs.
train_dl = data.DataLoader(train_set, batch_size, shuffle=True, num_workers=4, persistent_workers=True, pin_memory=True)
val_dl = data.DataLoader(val_set, batch_size, num_workers=4, persistent_workers=True, pin_memory=True)

## **Define parameters**

In [None]:
seed = 0
epoch = 0  # current epoch; read by train(), val(), demo() and save()
ema_decay = 0.998  # EMA decay used by train() from epoch 20 on (0.95 before)

# The number of timesteps to use when sampling
steps = 500

# The amount of noise to add each timestep when sampling
# 0 = no noise (DDIM)
# 1 = full noise (DDPM)
eta = 1.

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Using device:', device)
torch.manual_seed(seed)

# The EMA copy starts as an exact copy of the freshly initialised model.
model = Diffusion().to(device)
model_ema = deepcopy(model)
print('Model parameters:', sum(p.numel() for p in model.parameters()))

In [None]:
# Adam updates the live model only; model_ema is updated by ema_update() in train().
opt = optim.Adam(model.parameters(), lr=1e-4)
# Loss scaler for the autocast (mixed-precision) forward passes in eval_loss().
scaler = torch.cuda.amp.GradScaler()

# Use a low discrepancy quasi-random sequence to sample uniformly distributed
# timesteps. This considerably reduces the between-batch variance of the loss.
rng = torch.quasirandom.SobolEngine(1, scramble=True)

## **Training**

In [None]:
def eval_loss(model, rng, reals, classes):
    """Compute the reweighted 'v'-objective loss for a batch of clean images.

    Args:
        model: callable ``model(noised, log_snrs, classes)`` returning predicted v.
        rng: a 1-D ``torch.quasirandom.SobolEngine`` used to draw one timestep
            per image.
        reals: batch of clean images, batch first.
        classes: forwarded to `model` unchanged.

    Returns:
        Scalar tensor: the batch mean of each image's MSE between predicted and
        target v, weighted by alpha**2 = sigmoid(log_snr).
    """
    # Draw uniformly distributed continuous timesteps on the batch's device
    # (rather than a notebook-global `device`).
    t = rng.draw(reals.shape[0])[:, 0].to(reals.device)

    # Calculate the noise schedule parameters for those timesteps
    log_snrs = get_ddpm_schedule(t)
    alphas, sigmas = get_alphas_sigmas(log_snrs)
    # alpha**2 == sigmoid(log_snr): gives the v loss the same relative
    # weighting as the eps objective. sigmoid is the overflow-safe form of
    # exp(x) / (exp(x) + 1).
    weights = log_snrs.sigmoid()

    # Combine the ground truth images and the noise
    alphas = alphas[:, None, None, None]
    sigmas = sigmas[:, None, None, None]
    noise = torch.randn_like(reals)
    noised_reals = reals * alphas + noise * sigmas
    targets = noise * alphas - reals * sigmas

    # Compute the model output and the loss. torch.autocast('cuda') is the
    # non-deprecated form of torch.cuda.amp.autocast().
    with torch.autocast('cuda'):
        v = model(noised_reals, log_snrs, classes)
        return (v - targets).pow(2).mean([1, 2, 3]).mul(weights).mean()


def train():
    """Run one epoch over `train_dl`, updating `model` and its EMA copy.

    Uses the notebook globals model, model_ema, opt, scaler, rng, device,
    epoch and ema_decay, and logs the loss every 50 iterations.
    """
    # Fast-moving EMA (0.95) for the first 20 epochs, then the configured decay.
    decay = 0.95 if epoch < 20 else ema_decay
    for step, (images, labels) in enumerate(tqdm(train_dl)):
        opt.zero_grad()
        images = images.to(device)
        labels = labels.to(device)

        loss = eval_loss(model, rng, images, labels)

        # Scaled backward pass, optimizer step, then fold the new weights into the EMA.
        scaler.scale(loss).backward()
        scaler.step(opt)
        ema_update(model, model_ema, decay)
        scaler.update()

        if step % 50 == 0:
            tqdm.write(f'Epoch: {epoch}, iteration: {step}, loss: {loss.item():g}')


@torch.no_grad()
@torch.random.fork_rng()
@eval_mode(model_ema)
def val():
    """Report the EMA model's sample-weighted mean loss on `val_dl`.

    The RNG is reseeded inside a forked RNG state, and a fresh Sobol engine is
    used, so the global RNG stream used by training is left untouched.
    """
    tqdm.write('\nValidating...')
    torch.manual_seed(seed)
    sobol = torch.quasirandom.SobolEngine(1, scramble=True)
    loss_sum, n_seen = 0, 0
    for images, labels in tqdm(val_dl):
        images = images.to(device)
        labels = labels.to(device)

        batch_loss = eval_loss(model_ema, sobol, images, labels)

        loss_sum += batch_loss.item() * len(images)
        n_seen += len(images)
    mean_loss = loss_sum / n_seen
    tqdm.write(f'Validation: Epoch: {epoch}, loss: {mean_loss:g}')


@torch.no_grad()
@torch.random.fork_rng()
@eval_mode(model_ema)
def demo():
    """Sample a 10x10 image grid from the EMA model, save it and display it.

    The grid is written to ``demo_<epoch>.png``. Seeding happens inside a
    forked RNG state, so each call starts from the same noise.
    """
    tqdm.write('\nSampling...')
    torch.manual_seed(seed)

    start_noise = torch.randn([100, 3, 32, 32], device=device)
    # Labels 0..9, ten of each; Diffusion.forward ignores `cond`, though.
    grid_labels = torch.arange(10, device=device).repeat_interleave(10, 0)
    samples = sample(model_ema, start_noise, steps, eta, grid_labels)

    image_grid = utils.make_grid(samples, 10).cpu()
    out_path = f'demo_{epoch:05}.png'
    # NOTE(review): (x + 1) / 2 assumes [-1, 1] data, but training uses mean/std
    # normalisation, so the shown colours may be off — confirm.
    pil_image = TF.to_pil_image(image_grid.add(1).div(2).clamp(0, 1))
    pil_image.save(out_path)
    display.display(display.Image(out_path))
    tqdm.write('')


def save():
    """Write a training checkpoint to ``cifar_diffusion.pth``.

    Stores the state dicts of the model, its EMA copy, the optimizer and the
    AMP grad scaler, plus the current epoch (all notebook globals).
    """
    checkpoint_path = 'cifar_diffusion.pth'
    checkpoint = dict(
        model=model.state_dict(),
        model_ema=model_ema.state_dict(),
        opt=opt.state_dict(),
        scaler=scaler.state_dict(),
        epoch=epoch,
    )
    torch.save(checkpoint, checkpoint_path)


# Main loop: validate and sample once up front, then train indefinitely,
# validating/sampling every 5 epochs and checkpointing after every epoch.
# Interrupt the kernel (KeyboardInterrupt) to stop training.
try:
    val()
    demo()
    while True:
        print('Epoch', epoch)
        train()
        epoch += 1
        if epoch % 5 == 0:
            val()
            demo()
        save()
except KeyboardInterrupt:
    pass  # deliberate: interrupting the kernel is how training is stopped

## **Save model**

In [None]:
# Mount Google Drive so the checkpoint can be copied there.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
import shutil

# Copy the checkpoint written by save() to Drive under the CIFAR-100 file name.
source_path = '/content/cifar_diffusion.pth'
destination_path = '/content/drive/MyDrive/cifar100_diffusion.pth'

shutil.copyfile(source_path, destination_path)

## **Load model**

In [None]:
# Authenticate with Google and build a PyDrive client for downloading files.
# NOTE: this rebinds `drive`, shadowing the google.colab `drive` module
# imported in the "Save model" section.
auth.authenticate_user()
gauth = GoogleAuth()
gauth.credentials = GoogleCredentials.get_application_default()
drive = GoogleDrive(gauth)

In [None]:
# Download the pretrained checkpoint from Google Drive by its file id.
file_id = '1GJy9fHFPEt3acuAcV0syubKJOeCZaWnv' # URL id
downloaded = drive.CreateFile({'id': file_id})
downloaded.GetContentFile('cifar100_diffusion.pth')

In [None]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Using device:', device)

torch.manual_seed(0)

# Load the diffusion model used as a frozen feature extractor below.
# NOTE(review): this loads the raw 'model' weights; the checkpoint also has
# 'model_ema', which is what demo()/val() evaluate — confirm which is intended.
weights_path = '/content/cifar100_diffusion.pth'
saved_obj = torch.load(weights_path, map_location=torch.device('cpu'))
model_dif = Diffusion().to(device)
model_dif.load_state_dict(saved_obj['model'])
# A new nn.Module starts in training mode, which keeps its Dropout2d layers
# active and makes extracted features random. Switch to eval mode right away
# instead of relying on a later cell to do it.
model_dif.eval()

In [None]:
# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# print('Using device:', device)

# torch.manual_seed(0)

# weights_path = '/content/drive/MyDrive/cifar100_diffusion.pth'
# saved_obj = torch.load(weights_path, map_location=torch.device('cpu'))
# model_dif = Diffusion().to(device)
# model_dif.load_state_dict(saved_obj['model'])

## **UMAP feature visualization**

In [None]:
!pip install umap-learn

In [None]:
import torch
from torchvision.datasets import CIFAR10
from torch.utils.data import DataLoader, SubsetRandomSampler
import torchvision.transforms as transforms

import numpy as np
import matplotlib.pyplot as plt
import umap
from mpl_toolkits.mplot3d import Axes3D

In [None]:
# Extract diffusion features for the CIFAR-100 test split (same normalisation
# as training) to visualise them with UMAP.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.5071, 0.4867, 0.4408], [0.2675, 0.2565, 0.2761])
])


dataset = datasets.CIFAR100(root='./data', train=False, download=True, transform=transform)
# Random order over indices 0..9999 (the CIFAR-100 test split has 10,000 images).
sampler = SubsetRandomSampler(range(10000))
dataloader_small = DataLoader(dataset, batch_size=1000, sampler=sampler)

# eval() disables the Dropout2d layers so the features are deterministic.
model_dif.eval()
model_dif.to(device)

features_list = []
labels_list = []

for images, labels in dataloader_small:
    images = images.to(device)
    labels = labels.to(device)

    # Features of the un-noised images, evaluated at the log SNR of t = 0.001.
    t_up = torch.tensor([0.001] * len(labels)).to(device)
    log_snrs = get_ddpm_schedule(t_up)

    with torch.no_grad():
        log_snrs = log_snrs.to(device)
        features = model_dif.get_features(images, log_snrs, labels).to(device)

    features_list.append(features.cpu().detach().numpy())
    labels_list.append(labels.cpu().detach().numpy())

# Stack the per-batch results: one feature row and one label per image.
features_array = np.concatenate(features_list, axis=0)
labels_array = np.concatenate(labels_list, axis=0)

In [None]:
# (number of images, feature size); input_size below is set to 512 to match.
features_array.shape

In [None]:
from sklearn.preprocessing import MinMaxScaler
# Use a distinct name: `scaler` is the AMP GradScaler used by train()/save(),
# and rebinding it here would break any later training or checkpointing.
feature_scaler = MinMaxScaler()
X_scaled = feature_scaler.fit_transform(features_array)

# 2D UMAP embedding of the min-max-scaled features.
reducer = umap.UMAP(n_components=2, random_state=1)
X_umap = reducer.fit_transform(X_scaled)

In [None]:
colors = ['#0d0887', '#46039f', '#7201a8', '#9c179e', '#bd3786',
          '#d8576b', '#ed7953', '#fb9f3a', '#fdca26', '#f0f921']

# NOTE(review): CIFAR-100 has 100 classes, but only labels 0-9 (one colour
# each) are drawn here.
plt.figure(figsize=(12, 10))
for class_id, color in enumerate(colors):
    mask = labels_array == class_id
    plt.scatter(X_umap[mask, 0], X_umap[mask, 1], color=color, label=str(class_id), s=10)
plt.xlabel('UMAP Component 1', fontsize=14)
plt.ylabel('UMAP Component 2', fontsize=14)
plt.title('2D Projection of diffusion Features using UMAP', fontsize=16)
# Legend entries are CIFAR-100 class ids, not digits (an MNIST leftover).
plt.legend(title='Class', loc='upper right')
plt.grid(True, linestyle='--', alpha=0.6)
plt.tight_layout()
plt.savefig('umap_projection.png', dpi=300, bbox_inches='tight')
plt.show()

In [None]:
# Persist the 2D UMAP embedding (written as x_umap.npy).
np.save('x_umap', np.array(X_umap))

In [None]:
# Persist the labels aligned with the embedding rows (written as labels_array.npy).
np.save('labels_array', np.array(labels_array))

### **3D-projection**

In [None]:
# NOTE(review): unlike the 2D plot, this fits UMAP on the unscaled features.
reducer = umap.UMAP(n_components=3, n_neighbors=15, min_dist=0.1, random_state=42)
features_3d = reducer.fit_transform(features_array)

fig = plt.figure(figsize=(12, 10))
ax = fig.add_subplot(111, projection='3d')
# Draw the 3D embedding on the 3D axes. plt.scatter with the 2D X_umap, used
# here before, never displayed the 3D projection.
for class_id, color in enumerate(colors):
    mask = labels_array == class_id
    ax.scatter(features_3d[mask, 0], features_3d[mask, 1], features_3d[mask, 2],
               color=color, label=str(class_id), s=5)

ax.set_xlabel('UMAP Component 1', fontsize=14)
ax.set_ylabel('UMAP Component 2', fontsize=14)
ax.set_zlabel('UMAP Component 3', fontsize=14)

ax.grid(True, linestyle='--', alpha=0.6)
plt.tight_layout()
plt.savefig('umap_3d_projection_mnist.png', dpi=300, bbox_inches='tight')
plt.show()

In [None]:
# Persist the 3D embedding and its labels (overwrites labels_array.npy with the same labels).
np.save('features_3d', np.array(features_3d))
np.save('labels_array', np.array(labels_array))

## **Small Net training**

In [None]:
def extract_features(images, labels, t_up, batch_size, model_dif):
    """
    Extract diffusion-model features for a batch of images noised to time `t_up`.

    The clean images are noised with the forward diffusion process at continuous
    time `t_up` (x_t = alpha * x + sigma * eps, with fresh Gaussian noise eps) and
    passed to ``model_dif.get_features``, which concatenates averaged activations
    from several layers of the U-Net.

    :param images: batch of clean (normalised) images.
    :param labels: class labels; forwarded as the `cond` argument (ignored by
        this notebook's Diffusion model) and used for the batch size.
    :param t_up: time in [0, 1] of the forward diffusion process to noise to.
    :param batch_size: unused; kept for backward compatibility.
    :param model_dif: a diffusion model exposing ``get_features``.
    :return: the tensor returned by ``get_features``; for this notebook's
        Diffusion model, shape [batch_size, 512] on the CPU.
    """
    with torch.no_grad():
        images = images.to(device)
        labels = labels.to(device)
        t = torch.full((labels.shape[0],), float(t_up))
        log_snrs = get_ddpm_schedule(t).to(device)
        alphas, sigmas = get_alphas_sigmas(log_snrs)
        alphas = alphas[:, None, None, None]
        sigmas = sigmas[:, None, None, None]
        noise = torch.randn_like(images)
        noised_reals = images * alphas + noise * sigmas
        features = model_dif.get_features(noised_reals, log_snrs, labels)
    # get_features already returns a tensor; re-wrapping it with torch.tensor()
    # would only copy it and emit a UserWarning.
    return features

In [None]:
def train_model(train_loader, t_up, batch_size, model, criterion, optimizer, model_dif, epochs=90, loss_list=None):
    """Train classifier `model` on frozen diffusion features.

    Each batch is turned into features with ``extract_features`` at time
    `t_up`; only `model` is optimised.

    :param train_loader: yields (images, labels) batches.
    :param t_up: forward-diffusion time used for feature extraction.
    :param batch_size: forwarded to ``extract_features``.
    :param model: the classifier being trained.
    :param criterion: loss function ``criterion(outputs, labels)``.
    :param optimizer: optimiser over `model`'s parameters.
    :param model_dif: the frozen diffusion model providing features.
    :param epochs: number of passes over `train_loader`.
    :param loss_list: ignored (a fresh list is always returned); kept for
        backward compatibility.
    :return: list with the mean training loss of each epoch.
    """
    model.train()
    model_dif = model_dif.to(device)  # move once, not on every batch
    loss_list = []
    for epoch in tqdm(range(1, epochs + 1)):
        running_loss = 0.0
        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)
            features = extract_features(images, labels, t_up, batch_size, model_dif).to(device)
            outputs = model(features)
            optimizer.zero_grad()
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
        epoch_loss = running_loss / len(train_loader)
        loss_list.append(epoch_loss)
        # `epoch` already counts from 1, so it is printed as-is.
        print(f"Epoch {epoch}, Loss: {epoch_loss:.3f}")
    return loss_list


def test_model(model, test_loader, t_up, batch_size, model_dif=model_dif):
    """Return the accuracy (in %) of classifier `model` on `test_loader`.

    Images are turned into diffusion features with ``extract_features`` at
    time `t_up` before classification. The default `model_dif` is the
    notebook-global ``model_dif`` as it was when this cell was executed.
    """
    model.eval()
    n_correct, n_total = 0, 0
    with torch.no_grad():
        for batch_images, batch_labels in tqdm(test_loader):
            batch_images = batch_images.to(device)
            batch_labels = batch_labels.to(device)
            feats = extract_features(batch_images, batch_labels, t_up, batch_size, model_dif).to(device)
            logits = model(feats)
            predicted = torch.max(logits.data, 1)[1]
            n_total += batch_labels.size(0)
            n_correct += (predicted == batch_labels).sum().item()
    return 100 * n_correct / n_total

In [None]:
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.5071, 0.4867, 0.4408], [0.2675, 0.2565, 0.2761])
])
# The few-shot training subset must come from the *training* split. Drawing it
# from the test split means every training image is also scored by test_loader
# below (train/test leakage).
train_dataset = datasets.CIFAR100(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.CIFAR100(root='./data', train=False, download=True, transform=transform)

n_samples = 16  # training images kept per class
batch_size = 4
n_classes = 100

# Group indices by label using the dataset's stored targets, which avoids
# loading and transforming every image just to read its label.
class_indices = [[] for _ in range(n_classes)]
for idx, label in enumerate(train_dataset.targets):
    class_indices[label].append(idx)

selected_indices = []
for indices in class_indices:
    selected_indices.extend(indices[:n_samples])

sampler = SubsetRandomSampler(selected_indices)
train_loader = DataLoader(train_dataset, batch_size=batch_size, sampler=sampler)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

In [None]:
def split(dataset, num_train_per_class, num_classes=100):
    """Return the indices of the first `num_train_per_class` examples of each class.

    :param dataset: object with a ``targets`` sequence of integer labels
        (e.g. a torchvision CIFAR dataset).
    :param num_train_per_class: maximum number of indices kept per class.
    :param num_classes: labels ``0 .. num_classes - 1`` are scanned (default 100).
    :return: list of dataset indices, grouped by class in ascending label order.
    """
    # Convert the targets once rather than once per class.
    targets = torch.as_tensor(dataset.targets)
    train_indices = []
    for label in range(num_classes):
        indices = torch.where(targets == label)[0].tolist()
        train_indices.extend(indices[:num_train_per_class])
    return train_indices

class Net(nn.Module):
    """Two-layer MLP classifier: Linear -> ReLU -> Linear.

    Maps `input_size` features to `num_classes` logits through a hidden layer
    of width `hidden_dim`; both sizes are also kept as attributes.
    """

    def __init__(self, input_size, num_classes, hidden_dim=128):
        super().__init__()
        layers = [
            nn.Linear(input_size, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, num_classes),
        ]
        self.model = nn.Sequential(*layers)
        self.input_size = input_size
        self.num_classes = num_classes

    def forward(self, x):
        return self.model(x)

In [None]:
#from smallnet import LinearNet, Net, split_dataset

input_size = 512  # width of Diffusion.get_features output (256 + 128 + 64 + 64 channels)

In [None]:
# Sweep the forward-diffusion time used for feature extraction: for each
# t_up, train 5 freshly initialised classifiers and record their test
# accuracies (in %).
epochs = 10
#t_ups = [0.01, 0.03, 0.05, 0.1, 0.3, 0.45, 0.55, 1]
t_ups = [0.01, 0.03, 0.05, 0.07, 0.1, 0.2, 0.3, 0.4, 0.6, 0.8, 0.99]


accuracies = {}  # t_up -> list of test accuracies, one per run
for t_up in t_ups:
    print(f'--------- CALCULATIONS FOR t_up={t_up} ---------')
    accuracies[t_up] = []

    for i in range(1, 6):
        print(f'------ calc N=#{i} ------')

        model = Net(input_size, n_classes).to(device)
        criterion = nn.CrossEntropyLoss()
        optimizer = optim.Adam(model.parameters(), lr=0.001)

        loss_list = train_model(train_loader, t_up, batch_size, model, criterion, optimizer, model_dif, epochs=epochs)
        accuracy = test_model(model, test_loader, t_up, batch_size, model_dif=model_dif)
        print(f'accuracy = {accuracy}%')
        accuracies[t_up].append(accuracy)

    print(f'for {t_up}, accuracies=[' + ', '.join(list(map(str, accuracies[t_up]))) + ']')


In [None]:
# Display the per-t_up lists of test accuracies (%).
accuracies

In [None]:
import matplotlib.pyplot as plt
import numpy as np
from scipy.stats import norm

accuracy = {k: v for k, v in accuracies.items() if k not in [0.99, 0.46, 0.48, 0.52, 0.54, 0.06]}

# Express t_up in thousandths. round() rather than int(): a float product such
# as k * 1000 can land just below an integer, which int() would truncate.
accuracy_transformed = {round(k * 1000): v for k, v in accuracy.items()}
sorted_accuracy = dict(sorted(accuracy_transformed.items()))

t_up_values = list(sorted_accuracy.keys())
means = np.array([np.mean(sorted_accuracy[t_up]) for t_up in t_up_values])
# Shade mean ± one standard deviation. The std has the same units (% accuracy)
# as the mean; the variance (%²) is not a meaningful band width.
stds = np.array([np.std(sorted_accuracy[t_up]) for t_up in t_up_values])

plt.figure(figsize=(8, 8))
plt.plot(t_up_values, means, marker='o', linestyle='-', color='blue', markersize=8, markerfacecolor='red')
plt.fill_between(t_up_values, means - stds, means + stds, color='blue', alpha=0.2)

plt.axvline(x=50, color='black')
plt.text(110, 25, r'$t_{opt} = 50$', fontsize=12, verticalalignment='bottom', horizontalalignment='center', fontweight='heavy')

plt.xticks(t_up_values, t_up_values, rotation=45, fontsize=9)
plt.xlabel('t_up')
plt.ylabel('Mean Accuracy on test, %')
plt.title('Optimal timestep selection for CIFAR-100')
plt.grid(True, linestyle='--', alpha=0.7)
plt.tight_layout()

plt.show()

# Varying the training-set size

In [None]:
from torch.utils.data import Subset
from torch.utils.data import DataLoader

In [None]:
# Train a classifier on diffusion features (t_up = 0.05) using the first
# TRAIN_CLASS_SIZE training images of each class, and evaluate it on the full
# CIFAR-100 test split.
# NOTE(review): NUM_WORKERS is defined but not used by the loaders below.
NUM_WORKERS = 0

tf = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.5071, 0.4867, 0.4408], [0.2675, 0.2565, 0.2761]),
])
train_dataset = datasets.CIFAR100('data', train=True, download=True, transform=tf)
test_dataset = datasets.CIFAR100('data', train=False, download=True, transform=tf)
accuracy_list = list()  # test accuracy (%) for each training-set size
# TRAIN_CLASS_SIZE_list = [2, 4, 8, 16, 32, 64, 128, 256, 512, 1024]
# BATCH_SIZE_list =       [2, 4, 8, 16, 32, 64, 128, 128, 128, 128]
# TRAIN_CLASS_SIZE_list = [16, 32, 64, 128, 256, 512, 1024]
# BATCH_SIZE_list =       [16, 32, 64, 128, 128, 128, 128]
TRAIN_CLASS_SIZE_list = [256]
BATCH_SIZE_list =       [128]
for i in range(len(TRAIN_CLASS_SIZE_list)):
    print("____________________________________________")
    print("train class size:", TRAIN_CLASS_SIZE_list[i])

    BATCH_SIZE = BATCH_SIZE_list[i] #128
    TRAIN_CLASS_SIZE = TRAIN_CLASS_SIZE_list[i]

    #train_dataset = datasets.MNIST('data', train=True, download=True, transform=tf)

    # First TRAIN_CLASS_SIZE indices of each class (see split()).
    train_indices = split(train_dataset, num_train_per_class=TRAIN_CLASS_SIZE)
    train_subset = Subset(train_dataset, train_indices)
    train_loader = DataLoader(train_subset, batch_size=BATCH_SIZE, shuffle=True)


    #test_dataset = datasets.MNIST('data', train=False, download=True, transform=tf)

    test_subset = test_dataset
    test_loader = DataLoader(test_subset, batch_size=BATCH_SIZE)
    # print(len(train_loader))
    #from smallnet import LinearNet, Net, split_dataset
    input_size = 512  # embedding size
    num_classes = 100
    epochs = 100
    t_up = 0.05  # from 0 to 1!!!
    model = Net(input_size, num_classes).to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001)

    loss_list = train_model(train_loader, t_up, BATCH_SIZE, model, criterion, optimizer, model_dif, epochs=epochs)
    accuracy = test_model(model, test_loader, t_up, BATCH_SIZE, model_dif=model_dif)
    print(accuracy)
    accuracy_list.append(accuracy)