In [None]:
import pandas as pd
from datasets import load_dataset
from collections import defaultdict
from tqdm import tqdm
import regex as re
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
from torch.utils.data import Dataset, DataLoader
from collections import Counter
from torch.nn.utils.rnn import pad_sequence

dataset = load_dataset("coastalcph/tydi_xor_rc")

languages = ['ar', 'ko', 'te']
train = dataset["train"].filter(lambda example: example['lang'] in languages).to_pandas()
val = dataset["validation"].filter(lambda example: example['lang'] in languages).to_pandas()



In [None]:
def Tokenize(sentence):
    return re.findall(r"\w+", sentence.lower())

class Vocabulary():

    
    def __init__(self, texts):
        counter = Counter()
        for t in texts:
            counter.update(Tokenize(t))
        self.itos = ["<unk>", "<bos>", "<sep>", "<eos>"] + [w for w, c in counter.items()]
        self.stoi = {w: i for i, w in enumerate(self.itos)}

    def __len__(self):
        return len(self.itos)
    
    def Encode(self, sentence):
        return [self.stoi.get(word, 0) for word in Tokenize(sentence)]
    

class Data(Dataset):
    def __init__(self, df, vocab):
        self.data = df
        self.vocab = vocab
    
    def __getitem__(self, n):
        row = self.data.iloc[n]
        question = self.vocab.Encode(row["question"])
        context = self.vocab.Encode(row["context"])
        x = torch.tensor([1] + question + [2] + context + [3]) 
        y = torch.tensor(1 if row["answerable"] else 0) 
        return x, y
    
    def __len__(self):
        return len(self.data)
    

In [48]:
class LSTMClassifier(nn.Module):
    def __init__(self, vocab_size, embed_dim=128, hidden_dim=128, num_layers=1):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=0)
        self.lstm = nn.LSTM(embed_dim, hidden_dim, num_layers=num_layers, batch_first=True, bidirectional=True)
        self.fc = nn.Linear(hidden_dim * 2, 1)  # binary classification
        self.sigmoid = nn.Sigmoid()
    
    def forward(self, x):
        emb = self.embedding(x)
        _, (h, _) = self.lstm(emb)
        h = torch.cat((h[-2], h[-1]), dim=1)  # concat last fwd + bwd
        return self.sigmoid(self.fc(h)).squeeze()
    
def collate_fn(batch):
    xs, ys = zip(*batch)
    xs = pad_sequence(xs, batch_first=True, padding_value=0)
    ys = torch.stack(ys)
    return xs, ys


In [53]:
def TrainLSTM(train, val):
    vocab = Vocabulary(train["question"].tolist() + train["context"].tolist())

    train_ds = Data(train, vocab)
    val_ds = Data(val, vocab)

    train_loader = DataLoader(train_ds, batch_size=32, shuffle=True, collate_fn=collate_fn)
    val_loader = DataLoader(val_ds, batch_size=32, collate_fn=collate_fn)

    model = LSTMClassifier(len(vocab))
    criterion = nn.BCELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

    for epoch in range(5):
        model.train()
        for x, y in tqdm(train_loader):
            optimizer.zero_grad()
            preds = model(x)
            loss = criterion(preds, y.float())
            loss.backward()
            optimizer.step()
        print(f"Epoch {epoch+1}, train loss = {loss.item():.4f}")

    return model , val_loader

def ValidateModel(model, val_loader):
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for x, y in val_loader:
            preds = (model(x) > 0.5).long()
            correct += (preds == y).sum().item()
            total += y.size(0)

    print("Validation accuracy:", correct / total)


In [None]:
model, val_loader = TrainLSTM(train[train["lang"] == "ar"], val[val["lang"] == "ar"])
ValidateModel(model, val_loader)
#0.9566
model, val_loader = TrainLSTM(train[train["lang"] == "ko"], val[val["lang"] == "ko"])
ValidateModel(model, val_loader)
#0.9354
model, val_loader = TrainLSTM(train[train["lang"] == "te"], val[val["lang"] == "te"])
ValidateModel(model, val_loader)
#0.7943


100%|██████████| 80/80 [01:09<00:00,  1.15it/s]


Epoch 1, train loss = 0.0232


100%|██████████| 80/80 [01:01<00:00,  1.29it/s]


Epoch 2, train loss = 0.2467


100%|██████████| 80/80 [01:12<00:00,  1.11it/s]


Epoch 3, train loss = 0.0292


100%|██████████| 80/80 [01:05<00:00,  1.22it/s]


Epoch 4, train loss = 0.0836


100%|██████████| 80/80 [00:55<00:00,  1.43it/s]


Epoch 5, train loss = 0.0205
Validation accuracy: 0.9566265060240964


100%|██████████| 76/76 [00:49<00:00,  1.54it/s]


Epoch 1, train loss = 0.0293


100%|██████████| 76/76 [00:41<00:00,  1.85it/s]


Epoch 2, train loss = 0.1161


100%|██████████| 76/76 [00:40<00:00,  1.86it/s]


Epoch 3, train loss = 0.0253


100%|██████████| 76/76 [00:45<00:00,  1.66it/s]


Epoch 4, train loss = 0.0112


100%|██████████| 76/76 [00:39<00:00,  1.91it/s]


Epoch 5, train loss = 0.0020
Validation accuracy: 0.9353932584269663


100%|██████████| 43/43 [00:22<00:00,  1.94it/s]


Epoch 1, train loss = 0.5700


100%|██████████| 43/43 [00:24<00:00,  1.72it/s]


Epoch 2, train loss = 0.0275


100%|██████████| 43/43 [00:18<00:00,  2.34it/s]


Epoch 3, train loss = 0.0169


100%|██████████| 43/43 [00:19<00:00,  2.24it/s]


Epoch 4, train loss = 0.1392


100%|██████████| 43/43 [00:19<00:00,  2.20it/s]


Epoch 5, train loss = 0.0118
Validation accuracy: 0.7942708333333334
