In [None]:
# Hyper-parameters for the MNIST toy run in this cell.
batch_size = 2        # samples per mini-batch passed to the DataLoader
epochs = 10           # number of full passes over the training subset
learning_rate = 1e-2  # step size handed to the Adam optimizer

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import numpy as np


import torch 
import torch.nn as nn 
import torch.nn.functional as F
from tensorflow.keras.datasets import  mnist 




# Device used by every tensor/module created in this cell (CPU only here).
device = torch.device("cpu")

def conv2d_vectorized(x, conv_filter, stride=1, padding=0):
    """Cross-correlate one 2-D map with one 2-D filter via unfold + matmul.

    Args:
        x: tensor of shape [H, W].
        conv_filter: tensor of shape [kH, kW].
        stride: step between neighbouring windows.
        padding: zero padding added on every side of ``x``.

    Returns:
        Tensor of shape [out_H, out_W].
    """
    batched = x[None, None]  # [1, 1, H, W]
    kernel_h, kernel_w = conv_filter.shape

    # Every column of `patches` is one flattened receptive field: [1, kH*kW, L].
    patches = F.unfold(batched, kernel_size=(kernel_h, kernel_w), stride=stride, padding=padding)

    # One row of filter weights times all patches -> [1, 1, L].
    weights = conv_filter.flatten().reshape(1, -1)
    responses = torch.matmul(weights, patches)

    rows = (batched.shape[2] + 2 * padding - kernel_h) // stride + 1
    cols = (batched.shape[3] + 2 * padding - kernel_w) // stride + 1
    return responses.view(1, rows, cols).squeeze(0)


def conv2d(x, conv_filter, stride, padding):
    """Naive (loop-based) 2-D cross-correlation of one map with one filter.

    Working buffers are allocated on ``x.device`` rather than on a
    module-level ``device`` global, so the function works wherever the input
    lives.

    Args:
        x: tensor of shape [H, W].
        conv_filter: tensor of shape [kH, kW].
        stride: step between neighbouring windows.
        padding: zero padding added on every side of ``x``.

    Returns:
        Tensor of shape [out_H, out_W] on ``x.device``.
    """
    H, W = x.shape
    kH, kW = conv_filter.shape
    target_device = x.device

    # Zero-pad by copying the input into the centre of a larger buffer.
    x_padded = torch.zeros(H + 2 * padding, W + 2 * padding, device=target_device)
    x_padded[padding:H + padding, padding:W + padding] = x

    out_H = (H + 2 * padding - kH) // stride + 1
    out_W = (W + 2 * padding - kW) // stride + 1
    output_map = torch.zeros(out_H, out_W, device=target_device)

    for row in range(out_H):
        top = row * stride
        for col in range(out_W):
            left = col * stride
            window = x_padded[top:top + kH, left:left + kW]
            output_map[row, col] = torch.sum(window * conv_filter)
    return output_map


class ConvolutionLayer(nn.Module):
    """Multi-channel 2-D convolution implemented with ``F.unfold`` + matmul.

    Args:
        input_channels: number of channels in the input feature map.
        output_channels: number of filters (channels in the output).
        padding: zero padding added on every side of the input.
        stride: step between neighbouring windows.
        filter_size: side length of the square filters.

    The weights are created on the default device; move the layer with
    ``.to(device)`` like any other ``nn.Module``.
    """

    def __init__(self, input_channels, output_channels, padding, stride, filter_size):
        super().__init__()

        # forward() needs the channel counts; they were previously never stored,
        # which made forward() fail with AttributeError.
        self.input_channels = input_channels
        self.output_channels = output_channels
        self.number_of_filter = output_channels  # legacy alias of output_channels

        self.padding = padding
        self.stride = stride
        self.filter_size = filter_size

        # Small random init: one [C_in, k, k] filter per output channel.
        self.filters = nn.Parameter(
            torch.randn(output_channels, input_channels, filter_size, filter_size) * 0.1
        )

    def forward(self, x):
        """Convolve ``x`` of shape [B, C, H, W]; returns [B, F, out_H, out_W]."""
        B, C, H, W = x.shape

        # [F, C*k*k]: one flattened filter per row.
        filters_flat = self.filters.reshape(self.output_channels, -1)

        # [B, C*k*k, L]: one flattened receptive field per column.
        x_unf = F.unfold(x, kernel_size=self.filter_size, padding=self.padding, stride=self.stride)

        # The 2-D weight matrix broadcasts over the batch: [B, F, L].
        out = torch.matmul(filters_flat, x_unf)

        out_H = (H + 2 * self.padding - self.filter_size) // self.stride + 1
        out_W = (W + 2 * self.padding - self.filter_size) // self.stride + 1
        return out.reshape(B, self.output_channels, out_H, out_W)
    

class CNNMOdel(nn.Module):
    """Two custom convolution layers followed by a linear classifier.

    The classifier input size is hard-coded for 28x28 inputs: both convolutions
    use 3x3 filters with padding 1 and stride 1, so the spatial size is kept
    and the second layer emits 2 channels.

    Args:
        classes: number of output classes (logits per sample).
    """

    def __init__(self, classes):
        super().__init__()

        # A missing comma after ``output_channels=2`` made this a SyntaxError.
        self.conv1 = ConvolutionLayer(input_channels=1, output_channels=2, padding=1, stride=1, filter_size=3)
        self.conv2 = ConvolutionLayer(input_channels=2, output_channels=2, padding=1, stride=1, filter_size=3)
        num_features = 2*28*28  # channels * height * width after conv2
        self.classifier = nn.Linear(num_features, classes, device=device)

    def forward(self, x):
        """Map images [B, 1, 28, 28] to class logits [B, classes]."""
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = x.view(x.size(0), -1)  # flatten to [B, 2*28*28]
        return self.classifier(x)




(train_images, train_labels), (test_images, test_labels) = mnist.load_data()

# Keep only a handful of samples so the toy training loop finishes quickly.
train_images = train_images[:10]
train_labels = train_labels[:10]
test_images = test_images[:10]
test_labels = test_labels[:10]

# Convert to float tensors scaled to [0, 1]. float32 matches the dtype of the
# model parameters; bfloat16 inputs made the matmul inside ConvolutionLayer
# fail with a dtype-mismatch error.
train_images = torch.tensor(train_images, dtype=torch.float32) / 255.0
test_images = torch.tensor(test_images, dtype=torch.float32) / 255.0
train_labels = torch.tensor(train_labels, dtype=torch.long)
test_labels = torch.tensor(test_labels, dtype=torch.long)

# Add channel dimension: (N, C, H, W)
train_images = train_images.unsqueeze(1)  # (N, 1, 28, 28)
test_images = test_images.unsqueeze(1)


# Wrap the tensors in a dataset and iterate over it in shuffled mini-batches.
train_dataset = torch.utils.data.TensorDataset(train_images, train_labels)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

# Model, optimizer and loss.
model = CNNMOdel(classes=10).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
criterion = nn.CrossEntropyLoss()

total_steps = epochs * len(train_loader)

steps = 0

# Training: one optimizer update per mini-batch, a log line every 5 updates
# and one at the end of each epoch (reporting the last batch's loss).
for epoch in range(epochs):
    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        steps += 1
        if steps % 5 == 0:
            print (f"steps {steps} Loss {loss.item()}")

    print(f"Epoch {epoch+1}/{epochs}, Loss: {loss.item():.4f}")


In [None]:
from torch.utils.data import Dataset, DataLoader
from collections import Counter


from torchvision.datasets import CocoCaptions, CocoDetection
from torchvision import transforms
import torch 

# Apple-silicon GPU backend; this cell assumes an MPS-capable machine.
device = torch.device("mps")



# Image preprocessing applied to every sample: resize to 224x224, then
# convert the image to a tensor.
transform = transforms.Compose([
    transforms.Resize((224,224)),
    transforms.ToTensor()
])

# Caption dataset built from the images under 'train2017' and the captions
# annotation file.
train_dataset_cocooptions = CocoCaptions(
    root='train2017',
    annFile='annotations/captions_train2017.json',
    transform=transform
)


# Detection dataset over the same image folder, using the instances
# annotation file.
train_dataset_detection = CocoDetection(
    root='train2017',
    annFile='annotations/instances_train2017.json',
    transform=transform
)


# Number of leading samples kept from each dataset.
N = 1000

from torch.utils.data import Subset
# Rebind both names to views over the first N samples only.
train_dataset_cocooptions = Subset(train_dataset_cocooptions, range(N))
train_dataset_detection = Subset(train_dataset_detection, range(N))


# Collect every reference caption in the subset. Each dataset item is an
# (image, captions) pair, so this pass also loads the images.
caption_texts = [caption for _, captions in train_dataset_cocooptions for caption in captions]
all_captions = "\n".join(caption_texts)

# Tokenise caption by caption with the same rule encode() uses (split on a
# single space). Splitting the newline-joined blob instead fused the last word
# of one caption with the first word of the next ("table.\nA"), producing
# vocabulary entries that encode() could never match.
all_words = [word for caption in caption_texts for word in caption.split(" ")]

counter = Counter(all_words)

# Keep words seen more than 5 times, then append the special tokens.
vocab = [word for word, cnt in counter.items() if cnt > 5]
vocab += ["UNK", "<START>", "<END>", "<PAD>"]

# Lookup tables in both directions.
word2idx = {item: i for i, item in enumerate(vocab)}
idx2word = {i: item for i, item in enumerate(vocab)}


def encode(stri):
    """Turn a space-separated string into a list of vocabulary indices.

    Words missing from ``word2idx`` map to the index of ``"UNK"``.
    """
    unknown_id = word2idx["UNK"]
    token_ids = []
    for word in stri.split(" "):
        token_ids.append(word2idx.get(word, unknown_id))
    return token_ids

def decode(input_tensor):
    """Map an iterable of vocabulary indices back to their words via ``idx2word``."""
    words = []
    for index in input_tensor:
        words.append(idx2word[index])
    return words
    

class DataLoaderLite(Dataset):
    """Dataset yielding (image, fixed-length caption index tensor) pairs.

    Only the first caption of each item is used. It is wrapped in
    ``<START>`` / ``<END>``, encoded with ``encode`` and then padded with
    ``<PAD>`` or truncated to exactly ``caption_length`` entries.
    """

    def __init__(self, train_dataset_cocooptions, caption_length=50):
        self.train_dataset_cocooptions = train_dataset_cocooptions
        self.caption_length = caption_length

    def __len__(self):
        return len(self.train_dataset_cocooptions)

    def __getitem__(self, idx):
        image_tensor, image_captions = self.train_dataset_cocooptions[idx]
        token_ids = encode("<START> " + image_captions[0] + " <END>")

        shortfall = self.caption_length - len(token_ids)
        if shortfall > 0:
            # Too short: right-pad up to the target length.
            token_ids = token_ids + [word2idx["<PAD>"]] * shortfall
        else:
            # Long enough: keep the leading caption_length tokens.
            token_ids = token_ids[:self.caption_length]

        return image_tensor, torch.tensor(token_ids)


In [None]:
import torch.nn as nn 
import torch 
from torch.functional import F 
from torch.cuda.amp import GradScaler, autocast

from torch.utils.data import Dataset, DataLoader
from collections import Counter


from torchvision.datasets import CocoCaptions, CocoDetection
from torchvision import transforms
import torch 
from transformers import GPT2Tokenizer
from transformers import GPT2LMHeadModel
from pydantic import BaseModel


# Load the pretrained GPT-2 tokenizer and register the caption delimiters
# "<START>" and "<END>" as additional special tokens.
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
special_tokens = {"additional_special_tokens": ["<START>", "<END>"]}
tokenizer.add_special_tokens(special_tokens)


# NOTE(review): "<PAD>" is assigned as the pad token but is not part of the
# add_special_tokens call above, so it may not get a vocabulary id of its own
# - confirm what pad_token_id resolves to before relying on it as an
# ignore index.
tokenizer.pad_token = "<PAD>"
pad_token_id = tokenizer.convert_tokens_to_ids("<PAD>")
start_token_id = tokenizer.convert_tokens_to_ids("<START>") 
end_token_id = tokenizer.convert_tokens_to_ids("<END>") 


# Apple-silicon GPU backend; this cell assumes an MPS-capable machine.
device = torch.device("mps")

# Image preprocessing applied to every sample: resize to 224x224, then
# convert the image to a tensor.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor()
])

# Caption dataset built from the images under 'train2017' and the captions
# annotation file.
train_dataset_cocooptions = CocoCaptions(
    root='train2017',
    annFile='annotations/captions_train2017.json',
    transform=transform
)


# Detection dataset over the same image folder, using the instances
# annotation file.
train_dataset_detection = CocoDetection(
    root='train2017',
    annFile='annotations/instances_train2017.json',
    transform=transform
)


# Number of leading samples kept from each dataset.
N = 1000


from torch.utils.data import Subset
# Rebind both names to views over the first N samples only.
train_dataset_cocooptions = Subset(train_dataset_cocooptions, range(N))
train_dataset_detection = Subset(train_dataset_detection, range(N))


# all_captions = "\n".join([caption for captions_list in train_dataset_cocooptions for caption in captions_list[1]])
# all_words = list(all_captions.split(" "))



# counter = Counter()
# for word in all_words:
#     counter[word]+=1

# vocab = [word for word, cnt in counter.items() if cnt>5]
# vocab +=["UNK", "<START>", "<END>", "<PAD>"]


# word2idx =  {item:i for i, item in enumerate(vocab)}
# idx2word =  {i:item for i, item in enumerate(vocab)}


# def encode(stri):
#     all_tensor = [word2idx.get(word, word2idx["UNK"]) for word in stri.split(" ")]
#     return all_tensor 

# def decode(input_tensor):
#     return [idx2word[each] for each in input_tensor]


class DataLoaderLite(Dataset):
    """Dataset yielding (image, input_ids, attention_mask) triples.

    Only the first caption of each item is used. It is wrapped in
    ``<START>`` / ``<END>`` and tokenized, padded or truncated to
    ``caption_length`` tokens.
    """

    def __init__(self, train_dataset_cocooptions, caption_length=50, tokenizer=tokenizer):
        self.train_dataset_cocooptions = train_dataset_cocooptions
        self.caption_length = caption_length
        self.tokenizer = tokenizer

    def __len__(self):
        return len(self.train_dataset_cocooptions)

    def __getitem__(self, idx):
        image_tensor, image_captions = self.train_dataset_cocooptions[idx]

        # Delimit the caption so the decoder sees explicit begin/end markers.
        caption = "".join(("<START> ", image_captions[0], " <END>"))

        encoded = self.tokenizer(
            caption,
            max_length=self.caption_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt"
        )

        # return_tensors="pt" adds a leading batch axis of size 1; drop it.
        input_ids = encoded["input_ids"].squeeze(0)
        attention_mask = encoded["attention_mask"].squeeze(0)
        return image_tensor, input_ids, attention_mask
    
    # def __getitem__(self, idx):
    #     image_tensor, image_captions = self.train_dataset_cocooptions[idx]

    #     # Prepend <IMG> + <START>, Append <END>
    #     caption = "<IMG> <START> " + image_captions[0] + " <END>"
    #     caption_tensor = encode(caption)

    #     if len(caption_tensor) < self.caption_length:
    #         caption_tensor += [word2idx["<PAD>"]] * (self.caption_length - len(caption_tensor))
    #     else:
    #         caption_tensor = caption_tensor[:self.caption_length]

    #     return image_tensor, torch.tensor(caption_tensor)

    

# class DataLoaderLite(Dataset):

#     def __init__(self, train_dataset_cocooptions, caption_length=50):
#         self.train_dataset_cocooptions = train_dataset_cocooptions
#         self.caption_length = caption_length

#     def __len__(self):
#         return len(self.train_dataset_cocooptions)
    
#     def __getitem__(self, idx):
#         image_tensor, image_captions = self.train_dataset_cocooptions[idx]
#         caption = "<START> " + image_captions[0] + " <END>"
#         caption_tensor = encode(caption)

#         if len(caption_tensor) < self.caption_length:
#             caption_tensor += [word2idx["<PAD>"]] * (self.caption_length - len(caption_tensor))

#         else:
#             caption_tensor = caption_tensor[:self.caption_length]
#         return image_tensor, torch.tensor(caption_tensor)



# Module-level device used by the layers defined below (MPS backend).
device = torch.device("mps")


def conv2d(x, kernel, stride, padding):
    """Loop-based 2-D cross-correlation of a single map with a single kernel.

    Args:
        x: tensor of shape [H, W].
        kernel: tensor of shape [kH, kW].
        stride: step between neighbouring windows.
        padding: zero padding added on every side of ``x``.

    Returns:
        Tensor of shape [out_H, out_W] allocated on ``x.device``.
    """
    H, W = x.shape
    kH, kW = kernel.shape
    dev = x.device

    padded_h = H + 2 * padding
    padded_w = W + 2 * padding

    # Zero padding: copy the input into the centre of a larger zero buffer.
    x_padded = torch.zeros((padded_h, padded_w), device=dev)
    x_padded[padding:H + padding, padding:W + padding] = x

    out_shape = ((padded_h - kH) // stride + 1, (padded_w - kW) // stride + 1)
    feature_map = torch.zeros(out_shape, device=dev)

    # Slide the window; the enumerate index is the output coordinate.
    for out_i, top in enumerate(range(0, padded_h - kH + 1, stride)):
        for out_j, left in enumerate(range(0, padded_w - kW + 1, stride)):
            patch = x_padded[top:top + kH, left:left + kW]
            feature_map[out_i, out_j] = torch.sum(kernel.to(patch.device) * patch)
    return feature_map



class ConvolutionLayer(nn.Module):
    """Reference multi-channel convolution built on the loop-based ``conv2d``.

    Args:
        input_channels: channels expected in the input.
        output_channels: number of kernels / output channels.
        padding: zero padding added on every side of the input.
        stride: step between neighbouring windows.
        kernel_size: side length of the square kernels.
    """

    def __init__(self, input_channels, output_channels, padding, stride, kernel_size):
        super().__init__()
        self.input_channels = input_channels
        self.output_channels = output_channels
        self.padding = padding
        self.stride = stride
        self.kernel_size = kernel_size
        self.kernel = nn.Parameter(torch.randn(self.output_channels, self.input_channels, self.kernel_size, self.kernel_size))

    def forward(self, x):
        """Convolve ``x`` of shape [B, C, H, W]; returns [B, F, out_H, out_W]."""
        B, C, H, W = x.shape

        # The output size is the same for every (sample, filter) pair.
        out_H = (H + 2 * self.padding - self.kernel_size) // self.stride + 1
        out_W = (W + 2 * self.padding - self.kernel_size) // self.stride + 1

        output = []
        for batch in range(B):
            output_feature_map = []
            for each_output_channel in range(self.output_channels):
                # Accumulate on the input's device (conv2d allocates there too)
                # rather than on the module-level ``device`` global, which
                # mismatched whenever the input lived somewhere else.
                feature_map = torch.zeros((out_H, out_W), device=x.device)
                for each_input_channel in range(self.input_channels):
                    feature_map = feature_map + conv2d(
                        x[batch, each_input_channel],
                        self.kernel[each_output_channel, each_input_channel],
                        self.stride,
                        self.padding,
                    )
                output_feature_map.append(feature_map)
            output.append(torch.stack(output_feature_map))  # [F, out_H, out_W]
        return torch.stack(output)  # [B, F, out_H, out_W]
    


class ResnetGPT2Wrapper(nn.Module):
    """Prefix a GPT-2 style decoder with learned image tokens.

    ``num_img_tokens`` learned queries cross-attend over the image features;
    the resulting tokens are prepended to the caption token embeddings and the
    decoder is trained with next-token prediction on the caption.

    Args:
        gpt_decoder: causal LM exposing ``transformer.wte`` and accepting
            ``inputs_embeds`` / ``attention_mask`` keyword arguments, returning
            an object with a ``logits`` attribute.
        embed_size: embedding width shared by image features and decoder.
        vocab_size: decoder vocabulary size (stored for reference).
        num_img_tokens: number of image prefix tokens.
    """

    # Target value skipped by the cross-entropy loss.
    IGNORE_INDEX = -100

    def __init__(self, gpt_decoder, embed_size, vocab_size, num_img_tokens=5):
        super().__init__()
        self.gpt_decoder = gpt_decoder
        self.embed_size = embed_size
        self.vocab_size = vocab_size
        self.num_img_tokens = num_img_tokens

        self.mha = nn.MultiheadAttention(embed_dim=embed_size, num_heads=4, batch_first=True)
        # Projections keep the default parameter dtype. forward() casts the
        # image features to float32, so bf16 weights here only worked inside an
        # autocast region; autocast still runs these layers in low precision.
        self.key_proj = nn.Linear(embed_size, embed_size)
        self.value_proj = nn.Linear(embed_size, embed_size)
        self.query_proj = nn.Linear(embed_size, embed_size)
        self.layernorm = nn.LayerNorm(embed_size, eps=1e-6)
        self.dropout = nn.Dropout(0.1)
        self.img_queries = nn.Parameter(torch.randn(num_img_tokens, embed_size) * 0.01)

    def forward(self, img_features, captions_tensor, attention_mask=None):
        """Run the decoder on [image tokens ; caption] and compute the LM loss.

        Args:
            img_features: (B, N, D) image features.
            captions_tensor: (B, T) caption token ids.
            attention_mask: optional (B, T) mask, 1 for real tokens and 0 for
                padding. Padded positions are excluded from the loss.

        Returns:
            ``(logits, loss)`` where logits has shape (B, M + T - 1, vocab) and
            loss is the mean cross-entropy over non-ignored caption tokens.
        """
        img_features = img_features.float()

        tok_embeds = self.gpt_decoder.transformer.wte(captions_tensor)  # (B, T, D)
        B, T, D = tok_embeds.shape
        M = self.num_img_tokens

        # Learned queries attend over the image features -> (B, M, D).
        queries = self.img_queries.unsqueeze(0).expand(B, -1, -1)
        k = self.key_proj(img_features)    # (B, N, D)
        v = self.value_proj(img_features)  # (B, N, D)
        attended, _ = self.mha(self.query_proj(queries), k, v)
        img_tokens = self.layernorm(queries + attended)

        fused = self.dropout(torch.cat([img_tokens, tok_embeds], dim=1))  # (B, M+T, D)

        # The last position has nothing left to predict, so it is not fed in.
        inputs_embeds = fused[:, :-1, :].contiguous()  # (B, M+T-1, D)

        # Position i of the input must predict element i+1 of
        # [image tokens ; caption]. Image slots and padding are ignored.
        # The previous version shifted the labels here AND let the decoder
        # shift them again internally (predicting two tokens ahead), and filled
        # ignored slots with the pad id, which the decoder's loss does not skip.
        targets = captions_tensor
        if attention_mask is not None:
            targets = targets.masked_fill(attention_mask == 0, self.IGNORE_INDEX)
        img_ignore = torch.full(
            (B, M), self.IGNORE_INDEX, dtype=torch.long, device=captions_tensor.device
        )
        targets = torch.cat([img_ignore, targets], dim=1)[:, 1:].contiguous()  # (B, M+T-1)

        if attention_mask is not None:
            # Image tokens are always visible; trim to match inputs_embeds.
            img_mask = torch.ones(B, M, dtype=attention_mask.dtype, device=attention_mask.device)
            attention_mask = torch.cat([img_mask, attention_mask], dim=1)[:, :-1].contiguous()

        outputs = self.gpt_decoder(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
        )
        logits = outputs.logits

        loss = nn.functional.cross_entropy(
            logits.reshape(-1, logits.size(-1)).float(),
            targets.reshape(-1),
            ignore_index=self.IGNORE_INDEX,
        )
        return logits, loss

    
class LSTMEncoder(nn.Module):
    """LSTM caption decoder whose inputs are enriched by attention over image features.

    Each caption token embedding queries the image features through multi-head
    attention; the attended vector is concatenated with the embedding and fed
    to a 3-layer LSTM, followed by a projection to vocabulary logits.

    Args:
        embed_size: width of token embeddings and image features.
        hidden_size: LSTM hidden width.
        vocab_size: number of tokens in the vocabulary.
    """

    def __init__(self, embed_size, hidden_size, vocab_size):
        super().__init__()
        self.embedding_layer = nn.Embedding(vocab_size, embed_size)
        # The LSTM input is [embedding ; attended features], hence 2 * embed_size.
        self.lstm_layer = nn.LSTM(2*embed_size, hidden_size=hidden_size, batch_first=True, num_layers=3)
        self.mha = nn.MultiheadAttention(embed_dim=embed_size, num_heads=4, batch_first=True)
        # Default (float32) parameters: the embeddings feeding these layers are
        # float32, so bf16 weights raised a dtype error outside autocast.
        self.key_proj = nn.Linear(embed_size, embed_size)
        self.value_proj = nn.Linear(embed_size, embed_size)
        self.query_proj = nn.Linear(embed_size, embed_size)
        self.fc = nn.Linear(hidden_size, vocab_size)

    def forward(self, features, captions):
        """Return logits of shape (B, T-1, vocab_size).

        Args:
            features: (B, N, embed_size) image features.
            captions: (B, T) token ids; the last token is dropped so the input
                at step t is token t (teacher forcing).
        """
        embeddings = self.embedding_layer(captions[:, :-1])  # (B, T-1, E)

        query = self.query_proj(embeddings)
        keys = self.key_proj(features)
        values = self.value_proj(features)

        attn_out, attn_weights = self.mha(query, keys, values)  # (B, T-1, E)

        # Concatenate along the feature axis -> (B, T-1, 2E).
        lstm_in = torch.cat((embeddings, attn_out), dim=-1)
        outputs, _ = self.lstm_layer(lstm_in)

        return self.fc(outputs)  # (B, T-1, vocab_size)



def get_2d_sincos_pos_embed(embed_dim, grid_h, grid_w):
    """Build sine/cosine positional embeddings for a ``grid_h`` x ``grid_w`` grid.

    Each of the ``embed_dim // 2`` iterations contributes two columns (one for
    the row coordinate, one for the column coordinate), alternating between
    sine and cosine.

    Returns:
        bfloat16 tensor of shape (grid_h * grid_w, 2 * (embed_dim // 2)),
        with grid positions in row-major order.
    """
    rows = torch.arange(grid_h, dtype=torch.bfloat16)
    cols = torch.arange(grid_w, dtype=torch.bfloat16)
    yy, xx = torch.meshgrid(rows, cols, indexing='ij')

    # One (row, col) coordinate pair per grid cell: (H*W, 2).
    coords = torch.stack((yy, xx), dim=-1).reshape(-1, 2)

    def _column_pair(dim):
        # Two consecutive iterations share one frequency; even iterations use
        # sine, odd ones cosine.
        scale = 10000 ** (2 * (dim // 2) / embed_dim)
        wave = torch.sin if dim % 2 == 0 else torch.cos
        return wave(coords / scale)

    return torch.cat([_column_pair(dim) for dim in range(embed_dim // 2)], dim=1)


import torchvision.models as models

class ResnetEncoder(nn.Module):
    """ResNet-50 backbone that turns an image into a sequence of patch embeddings.

    Args:
        embed_size: width of the output embeddings.
        freeze_until_layer: number of leading backbone children whose
            parameters are frozen (``requires_grad = False``).
    """

    def __init__(self, embed_size, freeze_until_layer=5):
        super().__init__()
        resnet = models.resnet50(pretrained=True)

        # Drop the last two children of the network so the backbone returns a
        # spatial feature map instead of class logits.
        self.backbone = nn.Sequential(*list(resnet.children())[:-2])

        # Freeze the first `freeze_until_layer` children of the backbone.
        for position, stage in enumerate(self.backbone.children()):
            if position >= freeze_until_layer:
                break
            for weight in stage.parameters():
                weight.requires_grad = False

        # Project backbone channels to the requested embedding width.
        self.fc = nn.Linear(resnet.fc.in_features, embed_size)
        self.embed_size = embed_size

    def forward(self, x):
        """Map images (B, 3, H, W) to embeddings (B, H'*W', embed_size)."""
        feature_map = self.backbone(x)
        batch, channels, _, _ = feature_map.shape
        # Flatten the spatial grid and move channels last: (B, H'*W', C).
        tokens = feature_map.view(batch, channels, -1).permute(0, 2, 1)
        return self.fc(tokens)


# File that save_model() writes and load_model() reads.
checkpoint_path = "checkpoint.pth"


def save_model(image_encoder, caption_encoder):
    """Write both models to ``checkpoint_path`` in a single file.

    For each model the checkpoint stores its state dict, its class object and
    its ``args`` attribute (an empty tuple when the model has no ``args``).
    """
    checkpoint = {}
    for prefix, module in (("image_encoder", image_encoder), ("caption_encoder", caption_encoder)):
        checkpoint[f"{prefix}_state"] = module.state_dict()
        checkpoint[f"{prefix}_class"] = module.__class__
        checkpoint[f"{prefix}_args"] = getattr(module, "args", ())
    torch.save(checkpoint, checkpoint_path)


def load_model():
    """Rebuild both models from ``checkpoint_path``.

    The checkpoint stores, per model, a state dict, the model class and the
    positional constructor arguments; each model is re-instantiated from its
    class and arguments before the weights are loaded.

    Returns:
        ``(image_encoder, caption_encoder)``.
    """
    # SECURITY: the checkpoint contains pickled class objects, so it needs full
    # unpickling (weights_only=False), which can execute arbitrary code. Only
    # load checkpoints produced by this notebook / from trusted sources.
    # The flag is passed explicitly because torch.load defaults to
    # weights_only=True from torch 2.6, which rejects the pickled classes.
    model_dict = torch.load(checkpoint_path, map_location="mps", weights_only=False)

    image_encoder = model_dict["image_encoder_class"](*model_dict["image_encoder_args"])
    caption_encoder = model_dict["caption_encoder_class"](*model_dict["caption_encoder_args"])

    image_encoder.load_state_dict(model_dict["image_encoder_state"])
    caption_encoder.load_state_dict(model_dict["caption_encoder_state"])

    return image_encoder, caption_encoder


class CNNEncoder(nn.Module):
    """Four stride-2 convolutions that turn an image into patch embeddings.

    Args:
        embed_size: width of the output embeddings.
        input_shape: expected input shape ``(B, C, H, W)``; ``H`` and ``W``
            size the positional-embedding grid.

    The module is created on the default device; move it with ``.to(device)``.
    """

    # (kernel, stride, padding) shared by all four convolutions.
    _KERNEL, _STRIDE, _PADDING = 3, 2, 1
    _NUM_CONVS = 4

    def __init__(self, embed_size, input_shape):
        super().__init__()

        # More filters + strides to reduce spatial dims
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, stride=2, padding=1)  # 224 -> 112
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1) # 112 -> 56
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1) # 56 -> 28
        self.conv4 = nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1) # 28 -> 14

        # Spatial size after the conv stack, computed arithmetically instead of
        # running a throwaway forward pass on a hard-coded device.
        _, _, in_h, in_w = input_shape[:]
        grid_h = self._downsampled(in_h)
        grid_w = self._downsampled(in_w)

        pos_emb = get_2d_sincos_pos_embed(embed_size, grid_h, grid_w)  # (grid_h*grid_w, embed)
        self.register_buffer("pos_embed", pos_emb.unsqueeze(0))  # (1, grid_h*grid_w, embed)

        # The channel count after conv4 is known statically.
        self.fc = nn.Linear(self.conv4.out_channels, embed_size)

    @classmethod
    def _downsampled(cls, size):
        """Spatial extent of ``size`` after the four convolutions."""
        for _ in range(cls._NUM_CONVS):
            size = (size + 2 * cls._PADDING - cls._KERNEL) // cls._STRIDE + 1
        return size

    def forward(self, x):
        """Map images (B, 3, H, W) to embeddings (B, H'*W', embed_size)."""
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = F.relu(self.conv3(x))
        x = F.relu(self.conv4(x))

        B, C, H, W = x.shape
        x = x.reshape(B, C, H*W).permute(0, 2, 1)  # (B, H*W, C)
        N = x.shape[1]

        x_embed = self.fc(x)   # (B, N, embed_size)
        x_embed = x_embed + self.pos_embed[:, :N, :].to(x_embed.device)  # add positional embedding
        return x_embed
        



# # caption_encoder = LSTMEncoder(embed_size, hidden_size, vocab_size)
# from transformers import GPT2LMHeadModel
# gpt_decoder = GPT2LMHeadModel.from_pretrained("gpt2")
# gpt_decoder.resize_token_embeddings(gpt_decoder.get_input_embeddings().num_embeddings + 2)  # Example: add 3 tokens
# vocab_size = gpt_decoder.get_input_embeddings().num_embeddings



# gpt_hidden_size = gpt_decoder.config.hidden_size
# embed_size = gpt_hidden_size  # to match GPT2 hidden size
# hidden_size = gpt_hidden_size
# batch_size = 4
# input_channels = 3  
# image_h, image_w = 224, 224
# steps = 0
# epochs = 1
# lr = 1e-5
# accumulation_steps = 4  # simulate batch_size * 2

 

import gc
import math
import time

# Pretrained GPT-2 decoder. The embedding matrix is grown to the tokenizer's
# size so the special tokens registered on it have rows of their own.
gpt_decoder = GPT2LMHeadModel.from_pretrained("gpt2")
gpt_decoder.resize_token_embeddings(len(tokenizer))


checkpoint_path = "checkpoint.pth"
device = "mps"


class TrainingConfig(BaseModel):
    """Hyper-parameters for the captioning run.

    Fields are annotated so pydantic registers them as model fields, and
    ``vocab_size`` reads ``num_embeddings`` (the previous ``num_embedding``
    spelling raised AttributeError).
    """

    gpt_hidden_size: int = gpt_decoder.config.hidden_size
    embed_size: int = gpt_hidden_size  # to match GPT2 hidden size
    hidden_size: int = gpt_hidden_size
    batch_size: int = 4
    input_channels: int = 3
    image_h: int = 224
    image_w: int = 224
    steps: int = 0
    epochs: int = 1
    lr: float = 1e-5
    accumulation_steps: int = 4  # micro-batches accumulated per optimizer step
    vocab_size: int = gpt_decoder.get_input_embeddings().num_embeddings


training_config = TrainingConfig()

# Module-level aliases: the rest of this cell (and the training loop) reads the
# hyper-parameters as plain names. Binding them here also stops the cell from
# silently reusing `batch_size` / `epochs` left over from an earlier cell.
embed_size = training_config.embed_size
hidden_size = training_config.hidden_size
batch_size = training_config.batch_size
input_channels = training_config.input_channels
image_h, image_w = training_config.image_h, training_config.image_w
steps = training_config.steps
epochs = training_config.epochs
lr = training_config.lr
accumulation_steps = training_config.accumulation_steps
vocab_size = training_config.vocab_size


train_dataset_cocooptions=DataLoaderLite(train_dataset_cocooptions, caption_length=50, tokenizer=tokenizer)
train_dataloader = DataLoader(train_dataset_cocooptions, batch_size=batch_size, shuffle=True)

total_steps = len(train_dataloader)  * epochs

formatted_str = f"Training details vocab size {vocab_size} batch size {batch_size} image size {image_h}x{image_w}"
formatted_str+= f" total steps {total_steps} epochs {epochs}"
formatted_str+= f"Max loss {math.log(vocab_size)}"
formatted_str+= f"Perplexity {math.exp(math.log(vocab_size))}"


print (formatted_str)
scaler = GradScaler()

torch.mps.empty_cache()
gc.collect()

#encoder_model = CNNEncoder(embed_size, [batch_size, input_channels, image_h, image_w])

encoder_model = ResnetEncoder(embed_size)
encoder_model = encoder_model.to(device)

caption_encoder = ResnetGPT2Wrapper(gpt_decoder, embed_size, vocab_size)
caption_encoder = caption_encoder.to(device)

loss_fn = nn.CrossEntropyLoss(ignore_index=pad_token_id)

# Trainable encoder parameters plus every parameter of the caption model.
all_params = [param for param in encoder_model.parameters() if param.requires_grad] + list(caption_encoder.parameters())

optimizer = torch.optim.Adam(all_params, lr=lr)

# The scheduler advances once per optimizer step, i.e. once per accumulation
# group (including the shorter group at the end of each epoch). T_max must be
# an integer count of those steps.
optimizer_steps = epochs * math.ceil(len(train_dataloader) / accumulation_steps)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=max(1, optimizer_steps), eta_min=1e-6)

start_time = time.time()


for epoch in range(epochs):
    num_batches = len(train_dataloader)

    for step, batch in enumerate(train_dataloader):
        # batch: images [B, 3, 224, 224], caption ids [B, T], attention mask [B, T]
        image_tensor, caption_tensor, attention_mask = batch[0], batch[1], batch[2]
        image_tensor, caption_tensor, attention_mask = image_tensor.to(device), caption_tensor.to(device), attention_mask.to(device)

        global_step = epoch * num_batches + step + 1

        with torch.autocast("mps", enabled=True, dtype=torch.bfloat16):
            x_embed = encoder_model(image_tensor)  # (B, N, embed_size)
            logits, caption_loss = caption_encoder(x_embed, caption_tensor, attention_mask)
            # Scale so the accumulated gradient matches one large batch.
            loss = caption_loss / accumulation_steps

        scaler.scale(loss).backward()

        # Step the optimizer once per accumulation group, and also on the last
        # batch of the epoch so a short final group is not dropped.
        # unscale_() and gradient clipping belong here: they must run once per
        # optimizer step on the fully accumulated gradients, not after every
        # micro-batch.
        is_last_batch = (step + 1) == num_batches
        if (step + 1) % accumulation_steps == 0 or is_last_batch:
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(all_params, max_norm=5.0)
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad(set_to_none=True)
            scheduler.step()

        # Report progress, estimate remaining time and checkpoint every 100 steps.
        if global_step % 100 == 0:
            elapsed = time.time() - start_time
            steps_per_sec = global_step / elapsed
            remaining_steps = total_steps - global_step
            est_remaining = remaining_steps / steps_per_sec
            est_total = total_steps / steps_per_sec

            # `loss` is the scaled value; multiply back to report the true loss.
            print(f"epoch {epoch+1}/{epochs} step {step}/{len(train_dataloader)} "
                  f"Loss: {loss.item()*accumulation_steps:.4f} | "
                  f"Elapsed: {elapsed/60:.2f} min | "
                  f"ETA: {est_remaining/60:.2f} min | "
                  f"Total est: {est_total/60:.2f} min | "
                  f"Memory: {torch.mps.current_allocated_memory() / 1e9:.2f} GB , \\ {torch.mps.driver_allocated_memory() / 1e9:.2f} GB | "
                  f"Perplexity {math.exp(loss.item()*accumulation_steps):.2f}"
                  )

            save_model(image_encoder=encoder_model, caption_encoder=caption_encoder)

    # Release the last batch's tensors and cached device memory between epochs.
    del image_tensor, caption_tensor, x_embed, logits
    torch.mps.empty_cache()
    import gc; gc.collect()

In [None]:
from transformers import GPT2LMHeadModel
from transformers import GPT2Tokenizer
from transformers import AutoTokenizer
import os 
from transformers import GPTNeoForCausalLM, AutoTokenizer



class GPT2WithCross(GPT2LMHeadModel):
    """GPT-2 LM head model whose blocks also receive image features.

    ``forward`` bypasses the stock transformer forward: it embeds the tokens
    (when ids are given), calls every block in ``self.transformer.h`` with
    ``img_feats`` and ``attention_mask``, applies the final layer norm and the
    LM head, and optionally computes a shifted cross-entropy loss.
    """

    def forward(
        self,
        input_ids=None,
        inputs_embeds=None,
        attention_mask=None,
        img_feats=None,
        labels=None,
        **kwargs
    ):
        """Return ``{"loss": ..., "logits": ...}``; loss is None without labels."""
        # Token ids take precedence over precomputed embeddings.
        hidden_states = inputs_embeds if input_ids is None else self.transformer.wte(input_ids)

        # Every block gets the image features alongside the hidden states.
        for block in self.transformer.h:
            block_outputs = block(hidden_states, img_feats=img_feats, attention_mask=attention_mask)
            hidden_states = block_outputs[0]

        logits = self.lm_head(self.transformer.ln_f(hidden_states))

        if labels is None:
            return {"loss": None, "logits": logits}

        # Next-token objective: position t is scored against label t+1.
        # Targets equal to pad_token_id are ignored.
        predictions = logits[..., :-1, :].contiguous()
        targets = labels[..., 1:].contiguous()
        criterion = nn.CrossEntropyLoss(ignore_index=pad_token_id)
        loss = criterion(predictions.view(-1, predictions.size(-1)), targets.view(-1))

        return {"loss": loss, "logits": logits}


In [None]:
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader

from utils import setup_data
from setup_model import get_models 
from datasetlite import DataLoaderLite 


import math 
import torch 




#----- Model Setup -------

# get_models() (from setup_model) supplies the config object, both models,
# the pad token id and the tokenizer used by the rest of this notebook.
TrainingConfig, encoder_model, decoder_model , pad_token_id, tokenizer = get_models() 



# Free cached MPS memory and run the garbage collector before building the
# data pipeline.
torch.mps.empty_cache()
import gc; gc.collect()



def collate_fn(batch):
    """Merge dataset items into one padded batch.

    Each item of ``batch`` is an ``(image, input_ids, attention_mask)`` triple.
    Images are stacked along a new leading dimension; the two sequences are
    right-padded to the longest sequence in this batch, using the tokenizer's
    pad id for ``input_ids`` and 0 for ``attention_mask``.

    Returns:
        ``(images, input_ids, attention_mask)`` as batched tensors.
    """
    image_seq, ids_seq, mask_seq = zip(*batch)
    return (
        torch.stack(image_seq),
        pad_sequence(ids_seq, batch_first=True, padding_value=tokenizer.pad_token_id),
        pad_sequence(mask_seq, batch_first=True, padding_value=0),
    )


# Build the raw datasets, wrap the captioning splits so they yield tokenized
# captions, and create the loaders (only the training loader is shuffled).
train_dataset_cocooptions, val_dataset_cocooptions, train_dataset_detection , val_dataset_detection = setup_data(TrainingConfig.number_of_items)
train_dataset_cocooptions = DataLoaderLite(train_dataset_cocooptions, caption_length=TrainingConfig.caption_len, tokenizer=tokenizer)
val_dataset_cocooptions = DataLoaderLite(val_dataset_cocooptions, caption_length=TrainingConfig.caption_len, tokenizer=tokenizer)
train_dataloader = DataLoader(train_dataset_cocooptions, batch_size=TrainingConfig.batch_size, collate_fn=collate_fn, shuffle=True)
val_dataloader = DataLoader(val_dataset_cocooptions, batch_size=TrainingConfig.batch_size, collate_fn=collate_fn, shuffle=False)


# One "step" here is one micro-batch, not one optimizer update.
total_steps = len(train_dataloader) * TrainingConfig.epochs

# ln(vocab_size) is the cross-entropy of a uniform guess over the vocabulary,
# i.e. the loss an untrained model should start near; its exponential is the
# matching perplexity.
uniform_guess_loss = math.log(TrainingConfig.vocab_size)

# Every field is separated by a space; previously "epochs N", "Max loss" and
# "Perplexity" were concatenated without any separator.
formatted_str = (
    f"Training details vocab size {TrainingConfig.vocab_size} batch size {TrainingConfig.batch_size} "
    f"image size {TrainingConfig.image_h}x{TrainingConfig.image_w} "
    f"total steps {total_steps} epochs {TrainingConfig.epochs} "
    f"Max loss {uniform_guess_loss} "
    f"Perplexity {math.exp(uniform_guess_loss)}"
)

print (formatted_str)


In [None]:
from torch.cuda.amp import GradScaler
import numpy as np 
from utils import calculate_total_train_params, save_to_checkpoint
import torch.nn as nn 



# Gradient scaler used by the training loop around backward() / optimizer.step().
# NOTE(review): this is the CUDA GradScaler (torch.cuda.amp) while the loop runs
# under MPS bfloat16 autocast -- confirm that loss scaling is actually active
# and needed on this backend.
scaler = GradScaler()

# History of validation losses; the training loop appends to it after each
# eval() call and should_stop() inspects it.
loss_list = []


def eval(max_batches=3):
    """Estimate the validation loss on the first few batches of ``val_dataloader``.

    Reads the module-level ``encoder_model``, ``decoder_model``,
    ``val_dataloader`` and ``device``, and -- for the progress line only --
    ``epoch``, ``total_loss`` and ``train_dataloader``.

    NOTE: the name shadows the builtin ``eval``; it is kept because the training
    loop calls the function by this name.

    Args:
        max_batches: Maximum number of validation batches to evaluate. The
            default of 3 matches the previously hard-coded limit; ``None``
            evaluates the whole loader.

    Returns:
        Mean of the per-batch validation losses as a Python float.

    Raises:
        ValueError: If no validation batch was evaluated (empty loader or
            ``max_batches=0``), instead of an opaque ``ZeroDivisionError``.
    """
    from itertools import islice

    # --------------------
    #  Validation step
    # --------------------
    decoder_model.eval()
    encoder_model.eval()
    val_loss = 0
    count = 0
    try:
        with torch.no_grad():
            # islice stops after `max_batches` items without fetching an extra batch.
            for val_batch in islice(val_dataloader, max_batches):
                image_tensor, caption_tensor, attention_mask = [x.to(device) for x in val_batch]
                with torch.autocast("mps", enabled=True, dtype=torch.bfloat16):
                    x_embed = encoder_model(image_tensor)
                    _, val_caption_loss = decoder_model(x_embed, caption_tensor, attention_mask)
                val_loss += val_caption_loss.item()
                count += 1
    finally:
        # Always hand the models back in training mode, even if validation failed.
        decoder_model.train()
        encoder_model.train()

    if count == 0:
        raise ValueError("eval(): no validation batches were evaluated; cannot compute a validation loss")

    val_loss /= count
    # NOTE(review): `total_loss` is never reset in the visible training loop, so
    # the train_loss shown here is a running sum divided by the loader length,
    # not a per-epoch mean.
    print(f"Epoch {epoch+1}: train_loss={total_loss/len(train_dataloader):.4f}, val_loss={val_loss:.4f}")
    return val_loss



def should_stop(loss_list, window=100, threshold=0.4):
    """Decide whether training has plateaued, based on recent validation losses.

    The last ``window`` entries of ``loss_list`` are examined. Each consecutive
    difference is labelled "increasing" (> ``threshold``), "decreasing"
    (< -``threshold``) or "steady". Training should stop only when every
    difference in the window is "steady".

    Args:
        loss_list: Sequence of losses, oldest first.
        window: Number of most recent losses to inspect. The default of 100
            matches the previously hard-coded value.
        threshold: Absolute change between consecutive losses above which a
            step no longer counts as steady. The default of 0.4 matches the
            previously hard-coded value.

    Returns:
        ``True`` when at least ``window`` losses exist and all consecutive
        changes within the window are steady, otherwise ``False``.

    Raises:
        ValueError: If ``window`` is smaller than 2 (no differences to inspect).
    """
    if window < 2:
        raise ValueError("window must be at least 2 to compare consecutive losses")

    # Not enough history yet: never stop early.
    if len(loss_list) < window:
        return False

    recent_losses = loss_list[-window:]
    step_trends = []
    for delta in np.diff(recent_losses):
        if delta > threshold:
            step_trends.append("increasing")
        elif delta < -threshold:
            step_trends.append("decreasing")
        else:
            step_trends.append("steady")

    if all(trend == "steady" for trend in step_trends):
        return True

    print ("Trend", step_trends)
    return False



##### Setup Training #####
import math
import time

loss_fn = nn.CrossEntropyLoss(ignore_index=pad_token_id)
# Parameter collection built from both models; the same collection is handed to
# the optimizer and to gradient clipping.
all_params = calculate_total_train_params(encoder_model, decoder_model)


# The count covers `all_params`, which is derived from the encoder and the
# decoder, so the label names both.
print (f"Trainable parameters (encoder + decoder): {sum(p.numel() for p in all_params if p.requires_grad)/1e6} M")

optimizer = torch.optim.AdamW(all_params, lr=TrainingConfig.lr)

# scheduler.step() runs once per optimizer update. Per epoch that is one update
# for every full accumulation group plus one for a leftover partial group,
# i.e. ceil(len(loader) / accumulation_steps). T_max must be that integer count
# so the cosine reaches eta_min on the last update rather than overshooting.
optimizer_steps_per_epoch = math.ceil(len(train_dataloader) / TrainingConfig.accumulation_steps)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer,
    T_max=max(1, optimizer_steps_per_epoch * TrainingConfig.epochs),
    eta_min=1e-6,
)

start_time = time.time()
total_loss = 0 
best_val_loss = float("inf")
epochs_no_improve = 0
steps_no_improve = 0
patience_steps = 10
stop = False 
# From here on, batches are moved to the Apple-silicon GPU backend.
device = torch.device("mps")


def _apply_optimizer_step():
    """Unscale and clip the accumulated gradients, update the weights, advance the LR schedule."""
    # unscale_ may only be called once between scaler.update() calls, and
    # clipping is only meaningful on the fully accumulated gradient, so both
    # live here rather than after every backward().
    scaler.unscale_(optimizer)
    torch.nn.utils.clip_grad_norm_(all_params, max_norm=5.0)
    scaler.step(optimizer)
    scaler.update()
    optimizer.zero_grad(set_to_none=True)
    scheduler.step()


for epoch in range(TrainingConfig.epochs):
    # Backward passes accumulated since the last optimizer update.
    pending_micro_batches = 0

    for step, batch in enumerate(train_dataloader):
        image_tensor, caption_tensor, attention_mask = batch[0], batch[1], batch[2] # [B, 3, 224, 224], [B, T], [B, T]
        image_tensor, caption_tensor, attention_mask = image_tensor.to(device), caption_tensor.to(device), attention_mask.to(device)
        B, C, H, W = image_tensor.shape

        # 1-based count of micro-batches seen across all epochs.
        global_step = epoch * len(train_dataloader) + step + 1

        with torch.autocast("mps", enabled=True, dtype=torch.bfloat16):
            x_embed = encoder_model(image_tensor) # (B, N, embed_size)
            logits, caption_loss = decoder_model(x_embed, caption_tensor, attention_mask)  # (B, T-1, vocab_size)
            # Divide so the summed gradients of one accumulation group average out.
            loss = caption_loss / TrainingConfig.accumulation_steps

        scaler.scale(loss).backward()
        pending_micro_batches += 1
        if pending_micro_batches >= TrainingConfig.accumulation_steps:
            _apply_optimizer_step()
            pending_micro_batches = 0

        # Undo the accumulation scaling so the running total is in raw loss units.
        total_loss += loss.item() * TrainingConfig.accumulation_steps

        # Validate, checkpoint and test the early-stopping criterion every 10 micro-batches.
        if global_step % 10 == 0:
            val_loss = eval()
            loss_list.append(val_loss)
            save_to_checkpoint(encoder_model, decoder_model, optimizer, epoch, loss, global_step)

            if should_stop(loss_list):
                stop = True
                break

        # Estimate remaining time every 100 steps.
        if global_step % 100 == 0:
            elapsed = time.time() - start_time
            steps_per_sec = global_step / elapsed
            remaining_steps = total_steps - global_step
            est_remaining = remaining_steps / steps_per_sec
            est_total = total_steps / steps_per_sec

            # "\\ " prints the same backslash as before without relying on an
            # invalid escape sequence.
            print(f"epoch {epoch+1}/{TrainingConfig.epochs} step {step}/{len(train_dataloader)} "
                  f"Loss: {loss.item()*TrainingConfig.accumulation_steps:.4f} | "
                  f"Elapsed: {elapsed/60:.2f} min | "
                  f"ETA: {est_remaining/60:.2f} min | "
                  f"Total est: {est_total/60:.2f} min | "
                  f"Memory: {torch.mps.current_allocated_memory() / 1e9:.2f} GB , \\ {torch.mps.driver_allocated_memory() / 1e9:.2f} GB | "
                  f"Perplexity {math.exp(loss.item()*TrainingConfig.accumulation_steps):.2f}"
                  )

    # Flush gradients of a trailing, incomplete accumulation group (also after
    # an early stop) so they are not carried into the next epoch.
    if pending_micro_batches:
        _apply_optimizer_step()
        pending_micro_batches = 0

    if stop:
        break

    # Drop references to the last batch before clearing the MPS cache.
    del image_tensor, caption_tensor, x_embed, logits
    torch.mps.empty_cache()
    import gc; gc.collect()

In [None]:
! export PYTORCH_ENABLE_MPS_FALLBACK=1 

In [None]:
from torch.nn import functional as F
import matplotlib.pyplot as plt 
from transformers.generation.logits_process import LogitsProcessorList
from transformers import LogitsProcessorList, MinLengthLogitsProcessor, RepetitionPenaltyLogitsProcessor

from transformers.generation.logits_process import (
    TemperatureLogitsWarper,
    TopKLogitsWarper,
    TopPLogitsWarper,
)

# Chain multiple warpers
# Sampling warpers: temperature 0.8, then top-k 50, then nucleus (top-p) 0.95.
# NOTE(review): `warpers` is not referenced by generate_caption in this cell --
# confirm whether it was meant to be applied before sampling.
warpers = LogitsProcessorList([
    TemperatureLogitsWarper(0.8),
    TopKLogitsWarper(50),
    TopPLogitsWarper(0.95),
])

# Logit processors that generate_caption calls on the next-token scores.
processors = LogitsProcessorList([
    RepetitionPenaltyLogitsProcessor(penalty=1.2)
])

# Global autograd anomaly detection: a debugging aid that stays enabled for
# every later cell.
# NOTE(review): looks like a leftover debugging switch -- confirm it is still
# wanted for inference.
torch.autograd.set_detect_anomaly(True)



# Rebinds the module-level `device` to a plain string for the generation code below.
device = "mps"

def generate_caption(image_tensor, max_len=20, use_image=True):
    """Sample a caption for a single image, one token at a time.

    Reads the module-level ``image_encoder``, ``caption_encoder``,
    ``tokenizer``, ``processors`` and ``device``.

    Args:
        image_tensor: One un-batched image tensor; a batch dimension is added
            before it is passed to ``image_encoder``.
        max_len: Maximum number of tokens to sample after ``<START>``.
        use_image: Currently ignored; kept so existing calls keep working.

    Returns:
        The decoded caption with special tokens skipped.
    """
    caption_encoder.eval()
    image_encoder.eval()

    # Look both marker ids up once instead of on every loop iteration.
    start_id = tokenizer.convert_tokens_to_ids("<START>")
    end_id = tokenizer.convert_tokens_to_ids("<END>")

    # Inference only: keep the image encoder out of the autograd graph as well.
    with torch.no_grad():
        x_embed = image_encoder(image_tensor.to(device).unsqueeze(0))
        generated_ids = torch.tensor([[start_id]], device=device)

        for _ in range(max_len):
            attn_mask = torch.ones(1, generated_ids.shape[1], dtype=torch.bfloat16, device=device)
            logits, _ = caption_encoder(x_embed, generated_ids, attn_mask, mode="test")

            # Only the newest position predicts the next token. Post-process and
            # sample on CPU in float32 to avoid MPS sampling issues.
            next_logits = logits[:, -1, :].float().cpu()

            # Sample from the processed scores. Previously the processor output
            # was discarded, so the repetition penalty never took effect.
            next_logits = processors(generated_ids.cpu(), next_logits)
            probs = F.softmax(next_logits, dim=-1)
            next_token_id = torch.multinomial(probs, num_samples=1).to(device)

            generated_ids = torch.cat([generated_ids, next_token_id], dim=1)

            if next_token_id.item() == end_id:
                break

    # Index the single batch row explicitly so a one-token result is still a list.
    caption = tokenizer.decode(generated_ids[0].tolist(), skip_special_tokens=True)

    return caption


counter = 0 
# Example usage: caption the first image of the first training batch, show the
# image, and print the sampled caption next to the reference caption.
for step, batch in enumerate(train_dataloader):
    image_tensor, caption_tensor, attention_mask = batch[0], batch[1], batch[2] # [B, 3, 224, 224], [B, T], [B, T]
    image_tensor = image_tensor.to(device)
    caption_tensor = caption_tensor.to(device)
    attention_mask = attention_mask.to(device)
    B, C, H, W = image_tensor.shape

    caption_with_image = generate_caption(image_tensor[0], use_image=True)

    # matplotlib expects channels last, hence the permute before converting to numpy.
    plt.imshow(image_tensor[0].permute(1,2,0).cpu().numpy())
    print("With image context:\t \t", caption_with_image)
    print ("Actual\t\t", tokenizer.decode(caption_tensor[0].tolist(), skip_special_tokens=True))
    # Only the first batch is needed for the demo.
    break



In [None]:

# Scratch probe: displays the module-level object named "encoder" when one
# exists and is truthy, otherwise the placeholder string "DD".
globals().get("encoder") or "DD"