In [1]:
import torch as tch
import torchvision
import numpy as np
import matplotlib.pyplot as plt

In [2]:
device = 'cpu'
# batch_size=

In [3]:
n_directions = 6
direction_embedding_dimension = 32
time_embedding_dimension = 32

##dataset
# each databatch is has dimensions 10000x3072
# I should reshape it to (10000,3,32,32)

In [4]:
# class TimeEmbedding(nn.Module):
#     def __init__(self, embed_dim: int):
#         super().__init__()
#         self.embed_dim = embed_dim

#     def forward(self, t):
#         # t: (batch_size,) - the timestep
#         # Create the sinusoidal embedding
#         half_dim = self.embed_dim // 2
#         exponents = torch.arange(half_dim, dtype=torch.float32) / half_dim
#         freqs = torch.pow(10000, -exponents).to(t.device)
#         angles = t[:, None] * freqs  # Broadcasting over the batch dimension
#         # Combine sine and cosine
#         time_embedding = torch.cat([torch.sin(angles), torch.cos(angles)], dim=-1)
#         return time_embedding  # Shape: (batch_size, embed_dim)


In [5]:
#define architecture and U-net
# Unet is predicting the noise.
#  I have image x. What image might I get if I denoise by time "timestep" in the past in given direction.


# for now the same as direction embedding
class TimeEmbedding(tch.nn.Module):
    def __init__(self, time_embedding_dimension):
        super().__init__()
        self.time_embedding = tch.nn.Linear(1, time_embedding_dimension)
    def forward(self, timestep):
        
        timestep = timestep.view(-1,1).float()
        return self.time_embedding(timestep)

class DirectionEmbedding(tch.nn.Module):
    def __init__(self, n_classes, direction_embedding_dimension):
        super().__init__()
        self.direction_embedding = tch.nn.Embedding(n_classes, direction_embedding_dimension)
    def forward(self, class_label):
         return self.direction_embedding(class_label)

# a block in my modified UNet
class Block(tch.nn.Module):
    def __init__(self, in_channels, out_channels, time_embedding_dimension, direction_embedding_dimension):
        super().__init__()
        self.conv = tch.nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        self.norm = tch.nn.BatchNorm2d(out_channels)
        self.activation = tch.nn.ReLU()
        # rescale time_embedding and direction_embedding to match the dimension of the channels
        self.time_embedding_projection = tch.nn.Linear(time_embedding_dimension, in_channels)
        self.direction_embedding_projection = tch.nn.Linear(direction_embedding_dimension, in_channels)
    def forward(self, x, time_embedding, direction_embedding):
        
        batch_size, n_channels, height, width = x.shape
        # why -1 instead of n_channels?
        #  should be broadcastable to x.
        time_embedding = self.time_embedding_projection(time_embedding).view(batch_size,-1,1,1)
        direction_embedding = self.direction_embedding_projection(direction_embedding).view(batch_size,-1,1,1)

        # adding time embedding to input
        x = x+time_embedding+direction_embedding
        
        # forward pass
        x = self.conv(x)
        x = self.norm(x)
        x = self.activation(x)
        return x
    


In [6]:
class modifiedUnet(tch.nn.Module):
    def __init__(self, in_channels, out_channels, time_embedding_dimension, direction_embedding_dimension, n_classes):
        super().__init__()

        self.time_embedding = TimeEmbedding(time_embedding_dimension)
        self.direction_embedding = DirectionEmbedding(n_classes, direction_embedding_dimension)

        self.layer1 = Block(in_channels, 64, time_embedding_dimension,direction_embedding_dimension)
        self.layer2 = Block(64, 128, time_embedding_dimension,direction_embedding_dimension)
        self.layer3 = Block(128, 64, time_embedding_dimension,direction_embedding_dimension)
        self.layer4 = Block(64, out_channels, time_embedding_dimension,direction_embedding_dimension)

    def forward(self, x, class_label, timestep):
        time_embedding = self.time_embedding(timestep)
        direction_embedding = self.direction_embedding(class_label)

        # no skip connections for now
        x1 = self.layer1(x,time_embedding,direction_embedding)
        x2 =  self.layer2(x1,time_embedding,direction_embedding)
        x3 =  self.layer3(x2,time_embedding,direction_embedding)
        x4 =  self.layer4(x3,time_embedding,direction_embedding)
        return x4



In [7]:
#weights, loss function, optimizer
# Adam
# Instantiate the model
model = modifiedUnet(in_channels=1, out_channels=1, time_embedding_dimension=16, direction_embedding_dimension=16, n_classes=10)

# Dummy inputs
x = tch.randn(1, 1, 28, 28)  # Batch size = 1, Channels = 1, Height = 28, Width = 28
class_label = tch.tensor([3])  # Class index for the direction embedding
timestep = tch.tensor([5.0])  # Scalar timestep

# Forward pass
output = model(x, class_label, timestep)
print(output.shape)  # Should match the shape defined by the last Block (e.g., [1, 1, 28, 28] if out_channels=1)``
print(output)

torch.Size([1, 1, 28, 28])
tensor([[[[0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
           0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
           0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
           0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
           0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
           0.0000e+00, 1.0426e+00, 8.3844e-01],
          [0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
           0.0000e+00, 3.9586e-01, 0.0000e+00, 0.0000e+00, 0.0000e+00,
           0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
           0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
           0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
           0.0000e+00, 2.3074e-01, 1.7422e+00],
          [0.0000e+00, 0.0000e+00, 7.4985e-01, 6.2183e-01, 4.7046e-01,
           0.0000e+00, 7.4847e-01, 8.1501e-01, 0.0000e+00, 8.4364e-01,
           2.6670e-01, 6.