In [1]:
import os
import sys

import torch


# Directory that gets added to the import search path. It can be overridden
# with the DDPM_PIPELINE_ROOT environment variable; the fallback is the path
# this notebook originally hard-coded.
# NOTE(review): presumably this is where the `models` package imported in the
# next cell lives - confirm.
PIPELINE_ROOT = os.environ.get(
    "DDPM_PIPELINE_ROOT", "/home/jovyan/beomi/ayushraina/ddpm_pipeline"
)

# Guard the append so re-running this cell does not pile up duplicate entries.
if PIPELINE_ROOT not in sys.path:
    sys.path.append(PIPELINE_ROOT)

In [6]:
from models.time_embedding import SinusoidalTimeEmbedding
from models.convBlockForUnet import ConvBlockForUnet

def testConvBlockWithoutEmbeddings():
    
    # Generate a random batch of 16*16 RGB
    batchSize, Channels, Height, Width = 2,3,16,16
    x = torch.randn(batchSize, Channels, Height, Width)
    outChannels = 6
    
    # Forward Pass
    conv_block = ConvBlockForUnet(inChannels=3, outChannels=outChannels)
    output = conv_block.forward(x)
    
    # Check Output
    expectedShape = (batchSize, outChannels, Height, Width)
    assert output.shape == expectedShape, "Expected Shape is not Equal to Output Shape"
    
    print(f"Input Shape: {x.shape}")
    print(f"Output Shape: {output.shape}")
    print("Working Fine without Embeddings")
    
    return True
    
# Run the check right away so a failed shape assertion surfaces in this cell.
testConvBlockWithoutEmbeddings()

Input Shape: torch.Size([2, 3, 16, 16])
Output Shape: torch.Size([2, 6, 16, 16])
Working Fine without Embeddings


True

In [10]:
def testConvBlockWithEmbeddings():
    
    # Generate a random batch of 16*16 RGB
    batchSize, Channels, Height, Width = 2,3,16,16
    x = torch.randn(batchSize, Channels, Height, Width)
    outChannels = 8
    
    timeEmbeddingDimension = 32
    embedding_generator = SinusoidalTimeEmbedding(timeEmbeddingDimension)
    
    timesteps = torch.tensor([100,500])
    time_embedding = embedding_generator.forward(timesteps)
    
    # Forward Pass
    convBlock = ConvBlockForUnet(inChannels=3, outChannels=outChannels, timeEmbeddingDimension=timeEmbeddingDimension) # In this line also time dimension is set only
    output = convBlock.forward(x, time_embedding)
    
    # Check Output
    expectedShape = (batchSize, outChannels, Height, Width)
    assert output.shape == expectedShape, "Expected Shape is not equal to Output Shape"
    
    print(f"Input Shape: {x.shape}") 
    print(f"Time Embedding Shape: {time_embedding.shape}")
    print(f"Output Shape: {output.shape}") 
    print(f"Working fine with embeddings too")  
    
# Run the check right away so a failed shape assertion surfaces in this cell.
testConvBlockWithEmbeddings()

Input Shape: torch.Size([2, 3, 16, 16])
Time Embedding Shape: torch.Size([2, 32])
Output Shape: torch.Size([2, 8, 16, 16])
Working fine with embeddings too
