In [46]:
import torch.nn as nn
import torch.nn.functional as F
import torch
from torchvision import models
import torchtext.transforms as TT
from torch.utils.data import Dataset, DataLoader
from typing import List, Any
import torchvision.transforms as IT
import re
from torchtext.vocab import vocab
from faker import Faker
import random
import pandas as pd

faker = Faker()

import sys

sys.path.insert(0, "..")

from trainer import Trainer


def LDtoDL(l):
    result = {}
    for d in l:
        for k, v in d.items():
            result[k] = result.get(k, []) + [v]
    return result


def DLtoLD(d):
    if not d:
        return []
    result = [{} for i in range(max(map(len, d.values())))]
    for k, seq in d.items():
        for oneDict, oneValue in zip(result, seq):
            oneDict[k] = oneValue
    return result


class TestingDataset(Dataset):
    def __init__(self, data: List[dict]) -> None:
        self.data = data

    def __getitem__(self, index) -> Any:
        return self.data[index]

    def __len__(self) -> int:
        return len(self.data)


class TextTransform(nn.Module):
    def __init__(self):
        super(TextTransform, self).__init__()
        self.toTensor = TT.ToTensor()
        self.padding = TT.PadTransform(20, 0)

    def forward(self, x):
        x = self.toTensor(x)
        x = x[:20]
        x = self.padding(x)
        return x


class LSTMClassifier(nn.Module):
    def __init__(self, vocab_size, embedding_dim, hidden_dim, num_classes):
        super(LSTMClassifier, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.lstm = nn.LSTM(embedding_dim, hidden_dim, num_layers=1, batch_first=True)
        self.fc = nn.Linear(hidden_dim, num_classes)

    def forward(self, inputs):
        embedded = self.embedding(inputs)
        out, _ = self.lstm(embedded)
        out = self.fc(out[:, -1, :])
        return F.softmax(out)


class Resnet18(nn.Module):
    def __init__(self, num_classes=2):
        super(Resnet18, self).__init__()
        self.model = models.resnet18(weights="ResNet18_Weights.DEFAULT")
        self.model.fc = nn.Linear(self.model.fc.in_features, num_classes)

    def forward(self, inputs):
        out = self.model(inputs)
        return F.softmax(out)


class EnsembleModel(nn.Module):
    def __init__(
        self,
        text_model: LSTMClassifier,
        image_model: Resnet18,
    ):
        super().__init__()
        self.text_model = text_model
        self.image_model = image_model
        self.metadata_model = image_model
        self.fc = nn.Linear(
            text_model.fc.out_features + image_model.model.fc.out_features, 2
        )

    def forward(self, text, image, metadata):
        text_out = self.text_model(text)
        image_out = self.image_model(image)
        combined = torch.cat((text_out, image_out), dim=1)
        out = self.fc(combined)
        return out


image_transform = IT.Compose(
    [
        IT.Resize(256),
        IT.CenterCrop(224),
        IT.ToTensor(),
        IT.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)

text_transform = TextTransform()

In [48]:
def text_proprocess(texts):
    return re.match(r"[a-z, ]*", texts)[0].split()


vc = vocab({}, specials=["<pad>", "<unk>"])
vc.set_default_index(vc["<unk>"])

In [51]:
df_lenght = 10000
mock_df = pd.DataFrame(
    data={
        "post_message": [faker.text() for _ in range(df_lenght)],
        # "num_like_post": [random.randint(0, 10000) for _ in range(df_lenght)],
        # "user_id": [faker.name() for _ in range(df_lenght)],
        # "num_comment_post": [random.randint(0, 10000) for _ in range(df_lenght)],
        # "num_share_post": [random.randint(0, 10000) for _ in range(df_lenght)],
        "label": [random.randint(0, 1) for _ in range(df_lenght)],
    }
)

mock_df["post_message"] = mock_df["post_message"].map(lambda x: x.lower())
mock_df["images"] = mock_df["post_message"].map(lambda x: torch.rand(3, 256, 256))
mock_dataset = list(mock_df.to_dict("index").values())

In [39]:
train_dataset = TestingDataset(mock_dataset[: int(len(mock_dataset) * 0.8)])
test_dataset = TestingDataset(mock_dataset[int(len(mock_dataset) * 0.8) :])

In [41]:
train_iter = DataLoader(train_dataset, batch_size=4)
test_iter = DataLoader(test_dataset, batch_size=4)

In [50]:
for batch in train_iter:
    clean_text = [text_proprocess(text) for text in batch["post_message"]]
    print(clean_text)
    break

[['question', 'main', 'manager', 'community', 'behind', 'policy', 'left'], ['travel', 'wait', 'white', 'quality', 'finish'], ['though', 'few', 'court', 'billion', 'organization', 'physical', 'nearly'], ['note', 'on', 'baby', 'public', 'on', 'put', 'entire']]
