<font face='monospace'>
<h2><b>Inference Model</b></h2>

We use this model to check whether the output images of the diffusion model have a data distribution similar to that of the original dataset.

In [None]:
%pip install -qU fastai fastcore accelerate einops datasets torcheval matplotlib scipy numpy torch

In [None]:
import torch
import logging
import fastcore.all as fc
import torch.nn.functional as F
import torchvision.transforms.functional as TF

from torch import nn, optim
from pathlib import Path
from diffusion_ai import *
from functools import partial
from datasets import load_dataset
from torchvision import transforms
from torch.optim import lr_scheduler
from torcheval.metrics import MulticlassAccuracy

In [None]:
# Silence log records at WARNING level and below, configure how tensors are
# printed, and seed the random number generators for reproducibility.
logging.disable(logging.WARNING)
torch.set_printoptions(precision=2, linewidth=140, sci_mode=False)
torch.manual_seed(1)
# NOTE(review): set_seed comes from diffusion_ai; if it also seeds torch, it
# supersedes the manual_seed(1) call above — confirm which seed is intended.
set_seed(42)

In [None]:
# Directory that receives the saved model; exist_ok=True makes re-running
# this cell harmless when the directory is already there.
model_path = Path('models')
model_path.mkdir(exist_ok=True)

In [None]:
# Define constants
# Batch dictionary keys; IMAGE_KEY is read by transformi, LABEL_KEY is not
# referenced in the visible cells.
IMAGE_KEY, LABEL_KEY = 'image', 'label'
# Dataset identifier passed to load_dataset
DATASET_NAME = "fashion_mnist"
# Batch size handed to DataLoaders.from_dd
BATCH_SIZE = 1024
# Shift and scale applied to every image tensor in transformi
# (presumably the dataset's pixel mean and standard deviation — confirm)
X_MEAN, X_STD = 0.28, 0.35

In [None]:
@inplace
def transformi(batch):
    """
    Standardize every image of a batch using X_MEAN and X_STD.

    Each entry stored under IMAGE_KEY is converted with TF.to_tensor, shifted
    by X_MEAN and divided by X_STD; the resulting list replaces the original
    entry in the batch.

    Args:
        batch (dict): Batch of data; only the IMAGE_KEY entry is rewritten.
    """
    normalized = []
    for picture in batch[IMAGE_KEY]:
        tensor = TF.to_tensor(picture)
        normalized.append((tensor - X_MEAN) / X_STD)
    batch[IMAGE_KEY] = normalized


In [None]:
# Load the dataset and attach the normalizing transform; with_transform
# applies transformi on the fly when items are accessed.
# NOTE(review): trust_remote_code=True allows the dataset's loading script to
# run locally — confirm it is still required and that the source is trusted.
dataset = load_dataset(DATASET_NAME, trust_remote_code=True)
transformed_dataset = dataset.with_transform(transformi)
# Build the data loaders from the transformed dataset dict.
data_loaders = DataLoaders.from_dd(transformed_dataset, BATCH_SIZE, num_workers=2)

In [None]:
def create_model(activation_fn=nn.ReLU, filters=(16, 32, 64, 128, 256, 512), norm_layer=nn.BatchNorm2d):
    """
    Create a CNN classifier built from residual blocks.

    The network is a stride-1 stem block taking 1 input channel, followed by
    one stride-2 ResBlock per consecutive pair in `filters`, then a flatten,
    a bias-free linear layer to 10 outputs and a BatchNorm1d over them.

    Args:
        activation_fn (callable): Activation factory passed to every ResBlock.
        filters (tuple): Channel widths; filters[0] is the stem's output width
            and each following entry is the output width of one stride-2 block.
        norm_layer (callable): Normalization layer factory passed to every ResBlock.

    Returns:
        nn.Sequential: Constructed model.

    Note:
        The linear head expects exactly filters[-1] features after flattening,
        i.e. the final feature map must be 1x1 (presumably true for 28x28
        inputs with the default six-entry `filters` — confirm for other sizes).
    """
    # The stem width follows filters[0] (previously hard-coded to 16), so a
    # custom `filters` tuple lines up with the first stride-2 block's input.
    layers = [ResBlock(1, filters[0], ks=5, stride=1, act=activation_fn, norm=norm_layer)]
    # One downsampling block per consecutive (in, out) pair of widths.
    layers += [
        ResBlock(n_in, n_out, act=activation_fn, norm=norm_layer, stride=2)
        for n_in, n_out in zip(filters, filters[1:])
    ]
    layers += [nn.Flatten(), nn.Linear(filters[-1], 10, bias=False), nn.BatchNorm1d(10)]
    return nn.Sequential(*layers)

In [None]:
# Initialize metrics, callbacks, and other configurations
# Report multiclass accuracy under the name "accuracy"
metrics_cb = MetricsCB(accuracy=MulticlassAccuracy())
# ActivationStats receives a predicate matching GeneralRelu instances
# (presumably selecting which modules get tracked — confirm)
activation_stats_cb = ActivationStats(fc.risinstance(GeneralRelu))
callbacks = [DeviceCB(), metrics_cb, ProgressCB(plot=True), activation_stats_cb]
# Activation factory handed to create_model (leak=0.1, sub=0.4)
activation_general_relu = partial(GeneralRelu, leak=0.1, sub=0.4)
# Weight initializer applied through model.apply; leaky=0.1 matches the
# activation's leak value above
initialize_weights = partial(init_weights, leaky=0.1)

In [None]:
def transform_batch_elements(batch, transform_x=fc.noop, transform_y=fc.noop):
    """
    Run separate callables over the input and target parts of a batch.

    Args:
        batch (tuple): Indexable batch; element 0 holds the inputs and element 1 the targets.
        transform_x (callable): Applied to the inputs (fc.noop by default).
        transform_y (callable): Applied to the targets (fc.noop by default).

    Returns:
        tuple: The transformed inputs followed by the transformed targets.
    """
    new_inputs = transform_x(batch[0])
    new_targets = transform_y(batch[1])
    return new_inputs, new_targets


In [None]:
# Define data augmentations: random 28x28 crop after 1-pixel padding, random
# horizontal flip, then RandCopy (from diffusion_ai).
data_augmentations = nn.Sequential(
    transforms.RandomCrop(28, padding=1),
    transforms.RandomHorizontalFlip(),
    RandCopy()  # or use RandErase()
)
# Only the inputs are augmented; targets go through the default fc.noop.
# on_val=False presumably skips the augmentation during validation — confirm.
augmentation_cb = BatchTransformCB(partial(transform_batch_elements, transform_x=data_augmentations), on_val=False)

In [None]:
# Training configurations
EPOCHS = 1
LEARNING_RATE = 1e-2
# Step budget for the schedule: epochs times the length of the training loader
TOTAL_STEPS = EPOCHS * len(data_loaders.train)
# OneCycleLR factory with the peak learning rate and step budget bound; the
# optimizer argument is supplied later by whoever calls it
# (presumably BatchSchedCB, stepping once per batch — confirm)
scheduler = partial(lr_scheduler.OneCycleLR, max_lr=LEARNING_RATE, total_steps=TOTAL_STEPS)
extra_callbacks = [BatchSchedCB(scheduler), augmentation_cb]

# Create and initialize the model: build it with the GeneralRelu factory,
# then run initialize_weights over every submodule via .apply
model = create_model(activation_fn=activation_general_relu, norm_layer=nn.BatchNorm2d).apply(initialize_weights)
# Cross-entropy classification trained with AdamW; the shared callbacks are
# combined with the scheduler and augmentation callbacks
learner = TrainLearner(
    model, data_loaders, F.cross_entropy, lr=LEARNING_RATE,
    cbs=callbacks + extra_callbacks, opt_func=optim.AdamW
)

In [None]:
# Training: the epoch count must stay equal to the EPOCHS value used to
# compute TOTAL_STEPS, otherwise the OneCycleLR step budget no longer matches.
learner.fit(EPOCHS)

In [None]:
# Save the whole model object (not just its state_dict) into the models
# directory. NOTE(review): loading a fully pickled module requires the same
# class definitions to be importable at load time — confirm consumers have them.
torch.save(learner.model, model_path/'inference.pkl')

In [None]:
# Free memory: run a garbage-collection pass to release unreferenced objects,
# then clear this cell's output (wait=True delays clearing until new output
# is available).
import gc
from IPython.display import clear_output

gc.collect()
clear_output(wait=True)

In [None]:
# Remove all user-defined names from the interactive namespace; -f skips the
# confirmation prompt.
%reset -f