In [1]:
import torch
import torch.nn as nn

class SimplifiedCrossCondTransBase(nn.Module):
    
    def __init__(self, num_vq=1024, embed_dim=512, block_size=16):
        super().__init__()
        self.tok_emb = nn.Embedding(num_vq + 2, embed_dim)  # Token Embedding
        self.cond_emb = nn.Linear(5, embed_dim)  # Clip Feature Embedding
        self.pos_embed = nn.Embedding(block_size, embed_dim)  # Positional Embedding
        self.block_size = block_size

    def forward(self, idx, clip_feature):
        # Check for empty idx
        #print(f'idx: {idx}')
        if len(idx) == 0:
            token_embeddings = self.cond_emb(clip_feature).unsqueeze(1)
        else:
            b, t = idx.size()
            assert t <= self.block_size, "Cannot forward, model block size is exhausted."
            
            # Get token embeddings
            token_embeddings = self.tok_emb(idx)
            
            # Concatenate the clip_feature embeddings at the beginning
            token_embeddings = torch.cat([self.cond_emb(clip_feature).unsqueeze(1), token_embeddings], dim=1)
        
        # Add positional embeddings
        positions = torch.arange(0, token_embeddings.size(1)).unsqueeze(0).to(idx.device)
        pos_embeddings = self.pos_embed(positions)
        x = token_embeddings + pos_embeddings

        return x


In [2]:
# Initialize the model
model = SimplifiedCrossCondTransBase()

# Mock data
idx = torch.tensor([[1, 2, 3], [4, 5, 6]])  # Mock token indices for 2 samples in batch
clip_feature = torch.tensor([[0.1, 0.2, 0.3, 0.4, 0.5], [0.5, 0.4, 0.3, 0.2, 0.1]])  # Mock clip features for 2 samples in batch

# Forward pass
output = model(idx, clip_feature)

print(output.shape)  # Should print torch.Size([2, 4, 512]) since for each sample, we now have 4 tokens (1 for clip feature + 3 original tokens)


torch.Size([2, 4, 512])


In [4]:
a,b = idx.size()
b

3

In [5]:
import torch
import torch.nn as nn

class SimplifiedCrossCondTransBase(nn.Module):
    
    def __init__(self, num_vq=1024, embed_dim=512, block_size=16):
        super().__init__()
        self.tok_emb = nn.Embedding(num_vq + 2, embed_dim)  # Token Embedding, +2 for padding
        self.cond_emb = nn.Linear(5, embed_dim)  # Clip Feature Embedding ## output shape
        self.pos_embed = nn.Embedding(block_size, embed_dim)  # Positional Embedding
        self.block_size = block_size

    def forward(self, idx, clip_feature):
        # Check for empty idx
        if len(idx) == 0:
            token_embeddings = self.cond_emb(clip_feature).unsqueeze(1)
            print("Token embeddings shape (when idx is empty):", token_embeddings.shape)
        else:
            b, t = idx.size()
            assert t <= self.block_size, "Cannot forward, model block size is exhausted."
            
            # Get token embeddings
            token_embeddings = self.tok_emb(idx)
            print("Token embeddings shape (before concat):", token_embeddings.shape)
            
            # Concatenate the clip_feature embeddings at the beginning
            token_embeddings = torch.cat([self.cond_emb(clip_feature).unsqueeze(1), token_embeddings], dim=1)
            print("Token embeddings shape (after concat):", token_embeddings.shape)
        
        # Add positional embeddings
        positions = torch.arange(0, token_embeddings.size(1)).unsqueeze(0).to(idx.device)
        pos_embeddings = self.pos_embed(positions)
        print("Positional embeddings shape:", pos_embeddings.shape)
        
        x = token_embeddings + pos_embeddings
        print("Final output shape:", x.shape)

        return x

# Initialize the model
model = SimplifiedCrossCondTransBase()

# Mock data
idx = torch.tensor([[1, 2, 3], [4, 5, 6]])  # Mock token indices for 2 samples in batch
clip_feature = torch.tensor([[0.1, 0.2, 0.3, 0.4, 0.5], [0.5, 0.4, 0.3, 0.2, 0.1]])  # Mock clip features for 2 samples in batch

# Forward pass
output = model(idx, clip_feature)


Token embeddings shape (before concat): torch.Size([2, 3, 512])
Token embeddings shape (after concat): torch.Size([2, 4, 512])
Positional embeddings shape: torch.Size([1, 4, 512])
Final output shape: torch.Size([2, 4, 512])


In [6]:
import torch
import torch.nn as nn

# Dummy data
batch_size = 2
embed_dim = 3
clip_feature = torch.randn(batch_size, 5)  # [2, 5]
token_embeddings = torch.randn(batch_size, 4, embed_dim)  # [2, 4, 3]

# Linear layer to simulate self.cond_emb
cond_emb = nn.Linear(5, embed_dim)

# Embed the clip_feature
clip_feature_embedding = cond_emb(clip_feature).unsqueeze(1)  # [2, 1, 3]

# Concatenate
concatenated = torch.cat([clip_feature_embedding, token_embeddings], dim=1)  # [2, 5, 3]

print("Shape of token_embeddings:", token_embeddings.shape)
print("Shape of clip_feature_embedding:", clip_feature_embedding.shape)
print("Shape after concatenation:", concatenated.shape)


Shape of token_embeddings: torch.Size([2, 4, 3])
Shape of clip_feature_embedding: torch.Size([2, 1, 3])
Shape after concatenation: torch.Size([2, 5, 3])


In [13]:
ckpt_kit = torch.load('../pretrained/VQVAE_KIT/net_best_fid.pth', map_location ='cpu')
codebook = ckpt_kit['net']['vqvae.quantizer.codebook']

In [9]:
codebook.shape

torch.Size([512, 512])

In [None]:
codebook[0]

In [11]:
import torch.nn as nn

class SimpleModel(nn.Module):
    def __init__(self, input_size):
        super(SimpleModel, self).__init__()
        self.weights = nn.Parameter(torch.randn(input_size))

    def forward(self, x):
        return x * self.weights

model = SimpleModel(5)
for param in model.parameters():
    print(param)


Parameter containing:
tensor([-0.9491, -0.8419,  0.5876,  0.5137, -0.7643], requires_grad=True)


In [None]:
import torch

# Assuming codebook is loaded and its shape is [512, 512]
#codebook = torch.randn(512, 512)

# Generate a random idx tensor of shape [batch_size, sequence_length]
# where each value is an index to the codebook.
batch_size = 10
sequence_length = 20
idx = torch.randint(0, 512, (batch_size, sequence_length))  # random indices between 0 and 511
print(f'idx: {idx}')
# Fetch the embeddings for each idx
token_embeddings = torch.index_select(codebook, 0, idx.view(-1)).view(idx.size(0), idx.size(1), -1)

# Print the shape of token_embeddings
print(token_embeddings.shape)


In [28]:
idx.size(0)

10

In [21]:
idx

tensor([[279, 267, 101, 237, 422, 213, 192,  58, 160, 391, 117, 198,  18,   2,
         457, 478, 453, 139, 502, 160],
        [502, 243,  91, 491, 446, 388, 250, 195, 241, 247, 436, 124, 132, 359,
          50, 189, 347, 248, 395, 422],
        [170, 446, 201, 234, 276, 419, 225,  65, 362, 365, 380, 458, 279, 130,
         287, 497, 406,   1, 448, 463],
        [262, 381, 258, 491,  38, 135, 209, 134, 358, 171, 271,  21,  70, 491,
         344, 504,  28,  58, 164,  36],
        [ 66, 459, 297,  37, 188,  24, 494, 322, 242, 165, 401, 440, 252, 192,
          39, 245, 326, 439, 187, 468],
        [200, 107, 225, 195, 255, 347, 320, 501, 260, 369, 310, 329,  30, 166,
         250,  92, 457, 139, 268,  28],
        [175, 474, 266,  37, 272, 471, 160, 360,  45, 326, 324, 241, 167,  13,
         311, 282,  66, 396, 169,   8],
        [ 54,  60,  61, 136, 151, 371, 243, 379, 124, 243, 124, 499, 456,  95,
          99, 304,  46, 103, 445, 134],
        [282, 317, 148, 235, 151,  56, 458, 506,

In [None]:
token_embeddings

In [27]:
em1 = torch.index_select(codebook, 0, idx.view(-1))

In [26]:
em1.shape

torch.Size([512, 200])

In [29]:
tensor = torch.rand(2, 3)
reshaped = tensor.view(3, 2)  # Reshape it to [3, 2]
reshaped

tensor([[0.8887, 0.0367],
        [0.9229, 0.0377],
        [0.7410, 0.6245]])

In [32]:
tensor = torch.rand(2, 3)
print(tensor)
reshaped = tensor.reshape(3, 2)  # Reshape it to [3, 2]
print(reshaped)

tensor([[0.8400, 0.3243, 0.2485],
        [0.2472, 0.8400, 0.4535]])
tensor([[0.8400, 0.3243],
        [0.2485, 0.2472],
        [0.8400, 0.4535]])


In [31]:
tensor = torch.rand(2, 3)
reshaped = tensor.reshape(-1)  # Reshape it to [3, 2]
reshaped

tensor([0.8169, 0.9610, 0.7981, 0.0943, 0.6440, 0.8018])