In [None]:
import pandas as pd
import torch
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F

In [None]:
# Mount Google Drive (Colab only) so the embedding pickles under
# /content/drive/MyDrive are readable by the following cells.
from google.colab import drive

drive.mount(mountpoint='/content/drive')

Mounted at /content/drive


In [None]:
# Load the precomputed embedding tables from Drive.
# NOTE(review): pd.read_pickle can execute arbitrary code while unpickling —
# only load pickle files you produced yourself.
from pathlib import Path

DATA_DIR = Path("/content/drive/MyDrive/CNN")

documents_df = pd.read_pickle(DATA_DIR / "documents_embeding_df_full.pkl")
contexts_df = pd.read_pickle(DATA_DIR / "context_query_question_embeding_df_full.pkl")

In [None]:
# Show which columns the contexts table carries (DataFrame.keys() is its column Index).
contexts_df.keys()

Index(['level_0', 'index', 'title', 'search_query', 'search_query_embed',
       'question_embed', 'context_embed'],
      dtype='object')

In [None]:
# Contiguous positional split of the contexts table:
#   rows [0, TRAIN_END)        -> train
#   rows [TRAIN_END, TEST_END) -> test
#   rows [TEST_END, end)       -> validation
# NOTE(review): in the saved preview, rows sharing a `title` appear consecutively,
# so a title that straddles a boundary would land in two splits — check for
# leakage if that matters.
TRAIN_END = 63994
TEST_END = 71996

train_dataset = contexts_df.iloc[:TRAIN_END]
test_dataset = contexts_df.iloc[TRAIN_END:TEST_END]
validation_dataset = contexts_df.iloc[TEST_END:]

In [None]:
# Preview the training split; pandas truncates long frames to a head/tail view.
# NOTE(review): the saved output lacks `search_query` and `context_embed`,
# columns that are only dropped by a later cell, so it was produced by
# out-of-order execution — re-run top to bottom to get the real preview.
train_dataset

Unnamed: 0,level_0,index,title,search_query_embed,question_embed
0,0,0,آرسنال,"[0.06326603889465332, -0.31802576780319214, 0....","[-0.29263144731521606, -0.36100441217422485, 0..."
1,1,1,آرسنال,"[0.033652883023023605, -0.8137640953063965, -0...","[-0.17496638000011444, -0.39189279079437256, 0..."
2,2,2,آرسنال,"[-0.07896361500024796, 0.24103207886219025, 0....","[-0.41371777653694153, -0.252204030752182, 0.6..."
3,3,3,آرسنال,"[0.011871003545820713, -0.7781738638877869, 0....","[-0.09641078859567642, -0.1868000030517578, 0...."
4,4,4,آرسنال,"[0.3523325026035309, -0.6895809173583984, 0.16...","[-0.26271504163742065, -0.5740457773208618, 0...."
...,...,...,...,...,...
63989,63989,63989,ناصرالدین الطوسی,"[0.31674644351005554, 0.357119619846344, 0.170...","[0.05927535891532898, 0.4965742230415344, 0.54..."
63990,63990,63990,ناصرالدین الطوسی,"[0.34406182169914246, -0.014232856221497059, 0...","[0.15138590335845947, 0.15712127089500427, 0.5..."
63991,63991,63991,ناصرالدین الطوسی,"[0.23040416836738586, -0.6868681311607361, 0.4...","[0.08558045327663422, -0.20832280814647675, 0...."
63992,63992,63992,ناصرالدین الطوسی,"[-0.8551221489906311, 0.9815458059310913, 0.06...","[0.5811823606491089, 0.7674266695976257, 1.032..."


In [None]:
# Preview the document-embedding table.
# NOTE(review): the saved output already lacks the `context` column that a later
# cell drops. Either that cell ran first, or the pickle has no `context` column
# at all (in which case a strict drop of it raises KeyError) — confirm.
documents_df

Unnamed: 0,title,document_embeds
0,Iran–Iraq,"[-4.206642150878906, -0.9827616214752197, -30...."
1,آب,"[-3.42531681060791, -4.107446193695068, -4.057..."
2,آبادان,"[-6.622650146484375, 2.296736717224121, -4.247..."
3,آب‌انبار,"[-1.873834252357483, -1.6240180730819702, -2.8..."
4,آتش,"[-0.43937796354293823, -3.0440051555633545, -0..."
...,...,...
1118,یوتیوب,"[-23.651025772094727, -6.5472412109375, -10.97..."
1119,یونیورسال استودیوز,"[-6.778399467468262, -4.727050304412842, -8.10..."
1120,یونیکس,"[-9.365889549255371, -4.257607936859131, -3.79..."
1121,یوهان سباستیان باخ,"[-13.516144752502441, -10.824541091918945, -9...."


In [None]:
# Drop raw-text and otherwise unused columns; afterwards, per the saved outputs,
# documents_df keeps `title`/`document_embeds` and contexts_df keeps the ids,
# `title`, `search_query_embed` and `question_embed`.
# errors="ignore" makes this cell idempotent: the saved outputs show these
# columns already gone, and a strict drop raises KeyError on a re-run.
# NOTE(review): the train/test/validation slices are taken *before* this cell
# in notebook order, so they still carry these columns — consider running this
# before the split.
documents_df = documents_df.drop(columns=["context"], errors="ignore")
contexts_df = contexts_df.drop(
    columns=["id", "answers", "question", "context", "search_query", "context_embed"],
    errors="ignore",
)

In [None]:
class Attention(nn.Module):
    """Additive (concat) attention of one query vector over a set of document vectors.

    Args:
        hidden_size: feature size H shared by the query and document vectors.
    """

    def __init__(self, hidden_size):
        super().__init__()
        self.hidden_size = hidden_size
        # Scores each (query, document-position) pair from their concatenation.
        self.attn = nn.Linear(hidden_size * 2, hidden_size)
        self.v = nn.Linear(hidden_size, 1, bias=False)

    def forward(self, query, document):
        """Attend over ``document`` with ``query`` and append the query to the summary.

        Args:
            query: tensor of shape (B, H).
            document: tensor of shape (B, S, H), or (B, H) for a single document
                vector per query (treated as a length-1 sequence).

        Returns:
            Tensor of shape (B, 1, 2H): the attention-weighted document summary
            concatenated with the query.
        """
        if document.dim() == 2:
            # One document vector per query: treat it as a sequence of length 1.
            document = document.unsqueeze(1)
        query = query.unsqueeze(1)  # (B, 1, H)
        seq_len = document.size(1)
        # Broadcast the query to every document position (a view, no copy).
        query_per_position = query.expand(-1, seq_len, -1)  # (B, S, H)
        concat = torch.cat((query_per_position, document), dim=2)  # (B, S, 2H)
        energy = torch.tanh(self.attn(concat))  # (B, S, H)
        attention = F.softmax(self.v(energy), dim=1)  # (B, S, 1), normalised over positions
        weighted = torch.bmm(attention.transpose(1, 2), document)  # (B, 1, H)
        # Concatenate with the single query row, not the S-times repeated copy:
        # the repeated copy has S rows and cannot be joined with the 1-row
        # summary unless S == 1.
        return torch.cat((weighted, query), dim=2)

class SemanticSearchModel(nn.Module):
    """Score a query embedding against document embedding(s) using attention.

    The query and the documents are each projected to ``hidden_size`` by their
    own ReLU-activated linear layer, combined by ``Attention``, and mapped to a
    single score per batch element.

    Args:
        input_size: size of the incoming query/document embeddings.
        hidden_size: size of the projected representations.
    """

    def __init__(self, input_size, hidden_size):
        super().__init__()
        self.hidden_size = hidden_size
        self.query_fc = nn.Linear(input_size, hidden_size)
        self.document_fc = nn.Linear(input_size, hidden_size)
        self.attention = Attention(hidden_size)
        self.fc = nn.Linear(hidden_size * 2, 1)

    def forward(self, query, document):
        """Return one score per query.

        Args:
            query: query embeddings, presumably (B, input_size) — TODO confirm.
            document: document embeddings, presumably (B, S, input_size) — TODO confirm.

        Returns:
            Tensor of shape (B,), also when B == 1.
        """
        query = F.relu(self.query_fc(query))
        document = F.relu(self.document_fc(document))
        attention_output = self.attention(query, document)  # expected (B, 1, 2H)
        # NOTE(review): this final ReLU clamps every score to >= 0, so negative
        # pre-activations receive zero gradient. If the scores are used as logits
        # (e.g. with BCEWithLogitsLoss or a ranking loss), consider removing it.
        output = F.relu(self.fc(attention_output))  # (B, 1, 1)
        # Squeeze only the two trailing singleton dims. A bare .squeeze() would
        # also drop the batch dim when B == 1 (e.g. a final partial batch),
        # returning a 0-d tensor that no longer lines up with a (1,) label tensor.
        return output.squeeze(-1).squeeze(-1)

In [None]:
class DocumentDataset(torch.utils.data.Dataset):
    """Map-style dataset of aligned (query, document, label) triples.

    Item ``idx`` of each of the three sequences is returned together; they are
    assumed to have the same length (``__len__`` uses ``queries``). pandas
    Series/DataFrames are indexed by position (``.iloc``), so slices whose index
    labels do not start at 0 (e.g. rows 63994 onward) still yield items
    ``0 .. len-1``. Lists, arrays and tensors are indexed with plain ``[]``.
    """

    def __init__(self, queries, documents, labels):
        self.queries = queries
        self.documents = documents
        self.labels = labels

    def __len__(self):
        return len(self.queries)

    def __getitem__(self, idx):
        query = self._item_at(self.queries, idx)
        document = self._item_at(self.documents, idx)
        label = self._item_at(self.labels, idx)
        return query, document, label

    @staticmethod
    def _item_at(values, idx):
        """Return the ``idx``-th element of ``values`` by position."""
        # For pandas objects, plain [] with an integer key is label-based: it
        # raises KeyError or picks the wrong row when the index is not 0..n-1.
        if hasattr(values, "iloc"):
            return values.iloc[idx]
        return values[idx]

# Wrap the training and validation triples in datasets and batch them.
# NOTE(review): train_queries / train_documents / train_labels, their val_*
# counterparts and BATCH_SIZE are not defined in any visible cell, so this
# raises NameError on a fresh kernel — TODO add the cell that builds them.
# NOTE(review): this rebinds `train_dataset`, which previously held the
# training DataFrame slice.
train_dataset = DocumentDataset(
    train_queries,
    train_documents,
    train_labels,
)
train_loader = torch.utils.data.DataLoader(
    dataset=train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,  # reshuffle training examples every epoch
)

val_dataset = DocumentDataset(
    val_queries,
    val_documents,
    val_labels,
)
val_loader = torch.utils.data.DataLoader(
    dataset=val_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,  # keep validation order deterministic
)