In [None]:
import functools
import io
import os
import re
from collections import Counter

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split
from torch.nn.utils.rnn import pack_padded_sequence
from torchvision import models
from torchvision import transforms
from tqdm.auto import tqdm
import cv2
import nltk
from nltk import tokenize
from datasets import load_dataset, Image
import requests
import PIL.Image as pillow_image

In [None]:
class Config:
    # Central hyper-parameters and paths read by the cells below.

    # --- optimisation ---
    num_epochs = 30
    lr = 0.01
    momentum = 0.9  # NOTE(review): only meaningful for SGD; the visible cells use Adam
    batch_size = 64

    # --- checkpoint output ---
    model_output_path = "shiryoku_icm"
    model_filename = "shiryoku_vision.pth"

    # --- presumably mixture-of-experts routing; unused in the visible cells ---
    top_k = 2
    num_experts = 4

    # --- data ---
    image_size = 224
    train_size = 0.95  # fraction of samples assigned to the training split

    # --- model dimensions ---
    embed_size = 512
    hidden_size = 1024
    vocab_size = 50_000

In [None]:
# For image data
def image_transforms(image_file):
    """Apply the training-time augmentation pipeline to a single image.

    Args:
        image_file: the image to transform. Accepted forms are a PIL image,
            a channels-last (H, W, C) ndarray, or a channels-first (C, H, W)
            ndarray such as the one produced by ``read_img``.

    Returns:
        A normalised float tensor of shape (3, Config.image_size, Config.image_size).
    """
    pipeline = transforms.Compose(
        [
            # ToTensor goes first so the crop/flip below operate on tensors;
            # torchvision's RandomCrop rejects plain ndarray input.
            transforms.ToTensor(),
            # pad_if_needed avoids a crash on images smaller than the crop.
            transforms.RandomCrop(Config.image_size, pad_if_needed=True),
            transforms.RandomHorizontalFlip(),
            transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
        ]
    )

    image = image_file
    if hasattr(image, "convert"):
        # PIL input: force 3 channels so the 3-channel Normalize applies
        # to grayscale / RGBA images too.
        image = image.convert("RGB")
    elif (
        getattr(image, "ndim", None) == 3
        and image.shape[0] in (1, 3)
        and image.shape[-1] not in (1, 3)
    ):
        # ToTensor interprets ndarrays as (H, W, C); read_img returns (C, H, W).
        image = image.transpose(1, 2, 0)

    return pipeline(image)


def read_img(image_file, size=224):
    """Load an image file from disk as a channels-first RGB array.

    Args:
        image_file: path to the image (str or os.PathLike).
        size: side length of the square output image (default 224).

    Returns:
        An array of shape (3, size, size) in RGB channel order.

    Raises:
        FileNotFoundError: if OpenCV cannot read ``image_file``.
    """
    image = cv2.imread(str(image_file))
    if image is None:
        # cv2.imread reports failure by returning None rather than raising.
        raise FileNotFoundError(f"Could not read image file: {image_file}")
    image = cv2.resize(image, (size, size))
    # OpenCV decodes to BGR; convert so downstream RGB normalisation is right.
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    return image.transpose(2, 0, 1)  # (H, W, C) -> (C, H, W)


def get_image_from_url(url, timeout=10):
    """Download ``url`` and decode it as an RGB PIL image.

    Args:
        url: image URL.
        timeout: seconds to wait for the server before giving up (default 10).

    Returns:
        The decoded image converted to RGB, or None if the download or the
        decoding failed. Failures are printed, not raised, so one bad URL
        does not abort a bulk download.
    """
    try:
        response = requests.get(url, timeout=timeout)
        response.raise_for_status()
        # Decode from the fully-read body: the raw socket stream is not
        # content-decoded (gzip) and is closed once the response is released.
        image = pillow_image.open(io.BytesIO(response.content))
        # convert() forces the lazy decode now and guarantees 3 channels.
        return image.convert("RGB")
    except Exception as e:
        print(f"Failed to fetch image from {url}: {e}")
        return None


# For text data
def tokenize_text(text_input):
    """Lower-case ``text_input`` (coerced to ``str`` first) and split it into NLTK word tokens."""
    lowered = str(text_input).lower()
    return tokenize.word_tokenize(lowered)


@functools.lru_cache(maxsize=None)
def _english_stopwords():
    """Return NLTK's English stop-word list as a frozenset, loaded only once."""
    return frozenset(nltk.corpus.stopwords.words("english"))


def preprocess_text(text):
    """Normalise a caption into a list of content-word tokens.

    Steps: lower-case, strip every character outside ``[a-z0-9\\s]``,
    tokenise with ``tokenize_text``, then drop English stop words.

    Args:
        text: caption text (non-strings are coerced with ``str``).

    Returns:
        List of token strings.
    """
    # Lower-case BEFORE filtering: the character class only keeps a-z, so
    # upper-case letters would otherwise be deleted ("Dog" -> "og").
    cleaned = re.sub(r"[^a-z0-9\s]", "", str(text).lower())
    tokens = tokenize_text(cleaned)
    # Set membership instead of re-reading the stop-word list for every token.
    stop_words = _english_stopwords()
    return [t for t in tokens if t not in stop_words]


def create_vocabulary(text_dataset, min_count=1, max_size=None):
    """Build word <-> index lookup tables from raw caption strings.

    Captions are normalised with ``preprocess_text``. Indices 0-3 are
    reserved for "<pad>", "<start>", "<end>" and "<unk>"; the remaining
    words follow in descending frequency (ties keep first-seen order).

    Args:
        text_dataset: iterable of caption strings.
        min_count: drop words seen fewer than this many times (default 1,
            i.e. keep every word).
        max_size: optional cap on the TOTAL vocabulary size, special tokens
            included (e.g. ``Config.vocab_size``). The four special tokens
            are always kept. None means no cap.

    Returns:
        (word_to_idx, idx_to_word) dictionaries.
    """
    special_tokens = ["<pad>", "<start>", "<end>", "<unk>"]

    vocab_counter = Counter()
    for text in text_dataset:
        vocab_counter.update(preprocess_text(text))

    words = [
        word
        for word, count in vocab_counter.most_common()
        # Guard against a corpus word duplicating a special token, which
        # would give it two conflicting indices.
        if count >= min_count and word not in special_tokens
    ]
    vocab = special_tokens + words
    if max_size is not None:
        vocab = vocab[: max(max_size, len(special_tokens))]

    word_to_idx = {word: idx for idx, word in enumerate(vocab)}
    idx_to_word = dict(enumerate(vocab))

    return word_to_idx, idx_to_word


# for fetching the required datasets
def get_dataset(data_split):
    """Load one split of the ``google/docci`` dataset.

    Args:
        data_split: split name passed to ``load_dataset`` (e.g. "train").

    Returns:
        (images, img_captions): the split's "image" column, as returned by
        the datasets library, and its "description" column, both as lists.
    """
    docci_split = load_dataset("google/docci", split=data_split)

    # The "image" column is already decoded by the datasets library;
    # datasets.Image(...) is a feature *type* descriptor, not an image
    # constructor, so it must not be applied to the values.
    images = list(docci_split["image"])  # type: ignore
    img_captions = list(docci_split["description"])  # type: ignore

    return images, img_captions


def get_moondream_dataset(max_samples=None):
    """Download images and captions from ``isidentical/moondream2-coyo-5M-captions``.

    Every image is fetched over HTTP with ``get_image_from_url`` and kept in
    memory, so the full split (~5M rows) is very slow and memory-hungry. Pass
    ``max_samples`` to work with a prefix of the split.

    Args:
        max_samples: optional number of rows to use from the start of the
            "train" split. None (default) uses all rows.

    Returns:
        (images, captions): aligned lists. Rows whose image could not be
        downloaded or decoded are skipped, so both lists have equal length
        and images[i] matches captions[i].
    """
    moondream_dataset = load_dataset("isidentical/moondream2-coyo-5M-captions")
    md_data = moondream_dataset["train"]  # type: ignore
    if max_samples is not None:
        md_data = md_data.select(range(min(max_samples, len(md_data))))  # type: ignore

    image_urls = md_data["url"]  # type: ignore
    descriptions = md_data["moondream2_caption"]  # type: ignore

    images, captions = [], []
    for img_url, caption in tqdm(
        zip(image_urls, descriptions), total=len(image_urls), desc="downloading images"
    ):
        image = get_image_from_url(img_url)
        if image is None:
            # Skip the pair rather than storing None, which would crash the
            # dataset later and must not desynchronise images from captions.
            continue
        images.append(image)
        captions.append(caption)

    return images, captions

In [None]:
# Models

class ConvNetEncoder(nn.Module):
    """Small from-scratch CNN that encodes a batch of RGB images into feature vectors.

    Input: a float tensor of shape (batch, 3, H, W). Any H, W large enough to
    survive three conv(3x3)+maxpool(2) stages works, because the feature map is
    adaptively pooled to 7x7 before the fully-connected head.

    Args:
        embed_size: length of the output feature vector (default 256, the width
            of the original head).

    Note: the trailing BatchNorm1d needs batch size > 1 in training mode.
    """

    def __init__(self, embed_size=256):
        super().__init__()
        self.conv_net = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=3),
            nn.MaxPool2d(kernel_size=2),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=3),
            nn.MaxPool2d(kernel_size=2),
            nn.ReLU(),
            nn.Conv2d(64, 128, kernel_size=3),
            nn.MaxPool2d(kernel_size=2),
            nn.ReLU(),
            # A 224x224 input reaches this point as 26x26, not 7x7; pooling to a
            # fixed grid makes the Linear below match for any input size.
            nn.AdaptiveAvgPool2d((7, 7)),
            nn.Flatten(),
            nn.Linear(128 * 7 * 7, 256),
            # Without a nonlinearity the two Linears collapse into one.
            nn.ReLU(),
            nn.Linear(256, embed_size),
            # Must match the width of the preceding Linear.
            nn.BatchNorm1d(embed_size, momentum=0.01),
        )

    def forward(self, image):
        """Return features of shape (batch, embed_size) for ``image``."""
        encoded_image = self.conv_net(image)

        return encoded_image


class PretrainedConvNet(nn.Module):
    """Image encoder: frozen ImageNet ResNet-152 backbone plus a trainable projection.

    The classifier layer of the ResNet is dropped; the pooled backbone features
    are projected to ``embed_size`` and batch-normalised.

    Args:
        embed_size: length of the output feature vector.
    """

    def __init__(self, embed_size):
        super().__init__()
        weights_enum = getattr(models, "ResNet152_Weights", None)
        if weights_enum is not None:
            # torchvision >= 0.13 deprecates `pretrained=`; IMAGENET1K_V1 is the
            # checkpoint that `pretrained=True` loads.
            resnet = models.resnet152(weights=weights_enum.IMAGENET1K_V1)
        else:
            resnet = models.resnet152(pretrained=True)
        resnet_modules = list(resnet.children())[:-1]  # drop the final fc classifier

        self.resnet = nn.Sequential(*resnet_modules)
        # The backbone is never optimised: skip gradient bookkeeping for it.
        self.resnet.requires_grad_(False)
        self.resnet.eval()
        self.linear = nn.Linear(resnet.fc.in_features, embed_size)
        self.batch_norm = nn.BatchNorm1d(embed_size, momentum=0.01)

    def train(self, mode=True):
        """Set train/eval mode for the trainable head; the backbone always stays in eval mode.

        Without this override, ``encoder.train()`` would put the ResNet's
        BatchNorm layers in training mode, and their running statistics would
        drift (no_grad does not stop them updating) even though the backbone
        weights are frozen. ``eval()`` routes through this method as well.
        """
        super().train(mode)
        self.resnet.eval()
        return self

    def forward(self, images):
        """Return features of shape (batch, embed_size) for an image batch."""
        with torch.no_grad():
            features = self.resnet(images)

        features = features.reshape(features.size(0), -1)
        features = self.batch_norm(self.linear(features))

        return features


class TextRNNDecoder(nn.Module):
    """LSTM caption decoder conditioned on an image feature vector.

    The image feature is fed as time step 0, followed by the embedded caption
    tokens (teacher forcing). All tensors are batch-first.

    Args:
        vocab_size: number of token ids (size of the embedding and output layer).
        embed_dim: token-embedding width; must equal the image-feature width.
        hidden_size: LSTM hidden size.
        num_layers: number of stacked LSTM layers.
        max_len: number of tokens generated by ``sample``.
    """

    def __init__(self, vocab_size, embed_dim, hidden_size, num_layers=4, max_len=20):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, embed_dim)
        # batch_first=True: forward() packs with batch_first=True and sample()
        # feeds (batch, 1, embed_dim) steps. Without it, sample() would treat
        # the batch dimension as time and mix different images together.
        self.lstm = nn.LSTM(embed_dim, hidden_size, num_layers, batch_first=True)
        self.linear = nn.Linear(hidden_size, vocab_size)
        self.max_len = max_len

    def forward(self, features, captions, lengths):
        """Teacher-forced logits for a padded caption batch.

        Args:
            features: (batch, embed_dim) image features.
            captions: (batch, seq_len) padded token ids.
            lengths: per-sample lengths (list or tensor), sorted in descending
                order as required by ``pack_padded_sequence``.

        Returns:
            (sum(lengths), vocab_size) logits in packed order, aligned with
            ``pack_padded_sequence(captions, lengths, batch_first=True)[0]``.
        """
        text_embed = self.embed(captions)
        # Prepend the image feature along the TIME axis:
        # (batch, 1, E) + (batch, seq_len, E) -> (batch, 1 + seq_len, E).
        embeddings = torch.cat((features.unsqueeze(1), text_embed), dim=1)
        if torch.is_tensor(lengths):
            lengths = lengths.cpu()  # pack_padded_sequence requires CPU lengths
        packed = pack_padded_sequence(embeddings, lengths, batch_first=True)
        packed_output, _ = self.lstm(packed)
        rnn_output = self.linear(packed_output.data)

        return rnn_output

    def sample(self, features, state=None):
        """Greedily decode ``max_len`` token ids per image.

        Args:
            features: (batch, embed_dim) image features.
            state: optional initial LSTM (h, c) state.

        Returns:
            (batch, max_len) tensor of predicted token ids.
        """
        sampled = []
        input_seq = features.unsqueeze(1)  # (batch, 1, embed_dim)
        for _ in range(self.max_len):
            hidden, state = self.lstm(input_seq, state)
            outputs = self.linear(hidden.squeeze(1))
            _, predicted = outputs.max(1)
            sampled.append(predicted)
            input_seq = self.embed(predicted).unsqueeze(1)

        return torch.stack(sampled, 1)

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Seed torch so the random split, loader shuffling and augmentations are reproducible.
torch.manual_seed(0)

images, captions = get_moondream_dataset()

# create_vocabulary returns the tuple (word_to_idx, idx_to_word).
captions_vocab = create_vocabulary(captions)
word_to_idx, idx_to_word = captions_vocab


class ImageCaptionData(Dataset):
    """Map-style dataset of (image tensor, caption-id tensor) pairs.

    Args:
        images: sequence of image file paths (str / os.PathLike, loaded with
            ``read_img``) or already-decoded PIL images (e.g. from
            ``get_moondream_dataset``). PIL images need a ``transforms``
            callable that turns them into a tensor or array.
        captions: sequence of raw caption strings, aligned index-by-index
            with ``images``.
        transforms: optional callable applied to each image before it is
            converted to a float32 tensor.
        device: device the returned tensors are moved to. NOTE(review): moving
            to CUDA inside __getitem__ breaks DataLoader worker processes
            (num_workers > 0); prefer moving batches in the training loop.
        vocab: optional word -> index dict as returned by ``create_vocabulary``
            (must contain "<start>", "<end>" and "<unk>"). Built from
            ``captions`` when omitted.

    Raises:
        ValueError: if ``images`` and ``captions`` differ in length.
    """

    def __init__(self, images, captions, transforms, device, vocab=None):
        super().__init__()
        if len(images) != len(captions):
            raise ValueError(
                "images and captions must be aligned: got "
                f"{len(images)} images and {len(captions)} captions"
            )
        self.images = images
        self.captions = captions
        self.transform = transforms
        self.device = device
        if vocab is None:
            vocab, _ = create_vocabulary(captions)
        self.vocab = vocab

    def __len__(self):
        return len(self.images)

    def _load_image(self, idx):
        """Return image ``idx``: a read_img array for paths, an RGB PIL image otherwise."""
        source = self.images[idx]
        if source is None:
            raise ValueError(f"image {idx} is missing (failed download?)")
        if isinstance(source, (str, os.PathLike)):
            return read_img(source)
        if hasattr(source, "convert"):
            return source.convert("RGB")  # PIL image: force 3 channels
        return source

    def __getitem__(self, idx):
        image = self._load_image(idx)
        if self.transform:
            image = self.transform(image)

        # as_tensor avoids the copy + warning torch.tensor() emits for tensor input.
        image = torch.as_tensor(image, dtype=torch.float32).to(self.device)

        # Tokenise exactly as create_vocabulary does so ids line up with the vocab,
        # and wrap the caption in <start> ... <end>.
        unk_id = self.vocab["<unk>"]
        token_ids = [self.vocab["<start>"]]
        token_ids.extend(self.vocab.get(tok, unk_id) for tok in preprocess_text(self.captions[idx]))
        token_ids.append(self.vocab["<end>"])

        caption = torch.tensor(token_ids, dtype=torch.long).to(self.device)

        return image, caption


# The dataset and its train/validation split are built in the next cell.


In [None]:
def collate_caption_batch(batch):
    """Collate (image, caption_ids) pairs into the (images, captions, lengths) triple training_loop expects.

    Pairs are sorted by caption length, longest first, because
    pack_padded_sequence (used for the decoder input and the loss targets)
    requires descending lengths by default. Captions are right-padded with 0,
    the index create_vocabulary gives to "<pad>".
    """
    batch = sorted(batch, key=lambda pair: len(pair[1]), reverse=True)
    batch_images, batch_captions = zip(*batch)
    lengths = [len(caption) for caption in batch_captions]
    padded_captions = nn.utils.rnn.pad_sequence(
        list(batch_captions), batch_first=True, padding_value=0
    )
    return torch.stack(batch_images), padded_captions, lengths


# Pass the raw caption strings, not the (word_to_idx, idx_to_word) vocab tuple.
dataset = ImageCaptionData(
    images=images, captions=captions, transforms=image_transforms, device=device
)

# random_split needs integer lengths that sum exactly to len(dataset).
train_size = int(Config.train_size * len(dataset))
val_size = len(dataset) - train_size

train_data, valid_data = random_split(
    dataset, (train_size, val_size), generator=torch.Generator().manual_seed(0)
)

train_loader = DataLoader(
    train_data, batch_size=Config.batch_size, shuffle=True, collate_fn=collate_caption_batch
)
valid_loader = DataLoader(
    valid_data, batch_size=Config.batch_size, shuffle=False, collate_fn=collate_caption_batch
)

print(len(dataset))

In [None]:
# Both models must live on the same device the training loop moves batches to.
encoder = PretrainedConvNet(embed_size=Config.embed_size).to(device)
decoder = TextRNNDecoder(
    # Size the embedding/output layers from the vocabulary actually built from
    # the captions: any token id >= vocab_size would index past the embedding table.
    vocab_size=len(captions_vocab[0]),
    embed_dim=Config.embed_size,
    hidden_size=Config.hidden_size,
).to(device)
criterion = nn.CrossEntropyLoss()
# Only the decoder and the encoder's projection head are optimised; the
# ResNet backbone stays frozen.
parameters = (
    list(decoder.parameters())
    + list(encoder.linear.parameters())
    + list(encoder.batch_norm.parameters())
)
optimizer = optim.Adam(params=parameters, lr=Config.lr)
epochs = Config.num_epochs

In [None]:
def training_loop(train_loader, lossfn, optimizer, epochs=epochs):
    """Train the module-level ``encoder`` / ``decoder`` pair and checkpoint them.

    Args:
        train_loader: yields (images, captions, lengths) batches with captions
            padded and sorted by descending length (pack_padded_sequence order).
        lossfn: criterion called as ``lossfn(logits, targets)``.
        optimizer: optimizer over the trainable parameters.
        epochs: number of passes over ``train_loader``.

    Side effects: after every epoch writes ``decoder_{epoch}.pth`` and
    ``encoder_{epoch}.pth``, and at the end ``decoder_``/``encoder_`` +
    ``Config.model_filename``, all under ``Config.model_output_path``
    (created if missing).
    """
    # torch.save does not create parent directories.
    os.makedirs(Config.model_output_path, exist_ok=True)

    def _save_models(suffix):
        """Save decoder then encoder state dicts as decoder_<suffix> / encoder_<suffix>."""
        for prefix, model in (("decoder", decoder), ("encoder", encoder)):
            torch.save(
                model.state_dict(),
                os.path.join(Config.model_output_path, f"{prefix}_{suffix}"),
            )

    encoder.train()
    decoder.train()

    for epoch in tqdm(range(epochs), desc="epochs"):
        running_loss = 0.0
        num_batches = 0
        progress = tqdm(train_loader, desc=f"epoch {epoch + 1}/{epochs}", leave=False)
        for images, captions, lengths in progress:
            images = images.to(device)
            captions = captions.to(device)
            targets = pack_padded_sequence(captions, lengths, batch_first=True)[0]

            features = encoder(images)
            outputs = decoder(features, captions, lengths)
            loss = lossfn(outputs, targets)

            # Clears the gradients of every parameter being optimised.
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            batch_loss = loss.item()
            running_loss += batch_loss
            num_batches += 1
            progress.set_postfix(loss=f"{batch_loss:.4f}")

        # An empty loader must not leave the loss undefined.
        mean_loss = running_loss / num_batches if num_batches else float("nan")
        print(f"End metrics for epoch {epoch + 1} of {epochs}, mean loss: {mean_loss:.4f}")

        _save_models(f"{epoch}.pth")

    print("Training complete")
    _save_models(Config.model_filename)
    print("Models saved")