In [2]:
import pandas as pd
import torch
from tqdm import tqdm
from sklearn.metrics import f1_score


In [3]:
train_df = pd.read_json('train.json')
test_df = pd.read_json('test.json')


In [5]:
train_df = train_df.reset_index(drop=True)

In [6]:
test_df = test_df.reset_index(drop=True)

In [5]:
train_embedding = torch.load('all_train_embedding.pt')
test_embedding = torch.load('all_test_embedding.pt')

In [13]:
import torch
from torch.utils.data import Dataset, DataLoader

class CustomDataset(Dataset):
    def __init__(self, data, labels):
        self.data = data
        self.labels = labels

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx], self.labels[idx]
    

def pad_tensor_sequence(sequence, max_length, embedding_dim, padding_value=0):
    if sequence.size(0)>max_length:
        sequence = sequence[:max_length,:] ## Take the first max _length vector
    padding = torch.full((max_length - sequence.size(0), embedding_dim), padding_value)
    
    
    padded_sequence = torch.cat((sequence, padding), dim=0)
    
    attn_mask = torch.tensor(sequence.size(0)*[0]+padding.size(0)*[1],dtype=torch.float)
    
    return padded_sequence,attn_mask


def data_collator_with_padding(batch, embedding_dim, padding_value=0,max_length=128):
   
    batch_data_attn_mask = [pad_tensor_sequence(item[0], max_length, embedding_dim, padding_value) for item in batch]
    batch_labels = [torch.nn.functional.one_hot(torch.tensor(item[1],dtype=torch.long),7) for item in batch]
   
    batch_data = [item[0] for item in batch_data_attn_mask]
    batch_attention_mask = torch.stack([item[1] for item in batch_data_attn_mask])
    
    batch_data_tensor = torch.stack(batch_data)
    #print(batch_labels)
    batch_labels_tensor = torch.stack(batch_labels)
    
    #batch_labels_tensor = torch.tensor(batch_labels)

    return batch_data_tensor, batch_labels_tensor.float(), batch_attention_mask
    

# Example usage

#data = [v for v in neg_post_embedding_dict.values()]
#data.extend([p for p in adhd_post_embedding_dict.values()])
#labels = [0]*len(neg_post_embedding_dict)
#labels.extend([1]*len(adhd_post_embedding_dict))


#dataloader = DataLoader(dataset, batch_size=16, collate_fn=lambda batch: data_collator_with_padding(batch, 768))

#for batch_data, batch_labels,batch_attention_mask in a:
    
    #pass#print("Batch data shape:", batch_data.shape)
   # print("Batch labels:", batch_labels)

def create_dataloader_from_post_embedding(post_embedding,df,batch_size):
    data = []
    labels = []
    for author,frame in df.groupby('author'):
        data.append(post_embedding[frame.index])
        labels.append(frame['label'].iloc[0])
        
    dataset = CustomDataset(data,labels)
    dataloader = DataLoader(dataset, batch_size=batch_size, collate_fn=lambda batch: data_collator_with_padding(batch, 768))
    
    return dataloader
        
        
    

In [35]:
train_dataloader = create_dataloader_from_post_embedding(train_embedding,train_df,32)
test_dataloader = create_dataloader_from_post_embedding(test_embedding,test_df,32)

In [36]:
from torch import nn
import torch
class UserEmbedder(nn.Module):
    
    def __init__(self,n_layer=4):
        super().__init__()
        self.layers = nn.ModuleList([nn.MultiheadAttention(768, 6,batch_first=True) for _ in range(n_layer)])
        #self.self_attention = nn.MultiheadAttention(768,4,batch_first=True)
        self.layer_norm = nn.ModuleList([nn.LayerNorm(768) for _ in range(n_layer)])
        self.mean_pooling = nn.MaxPool2d(kernel_size=(2,2))
        
    def forward(self,x,key_padding_mask=None):
        residual = x
        for multihead_attention,layer_norm in zip(self.layers,self.layer_norm):
           
            x,_ = multihead_attention(x,x,x,key_padding_mask=key_padding_mask)
            x = residual+x
            x = layer_norm(x)
            residual = x
            
        
        x = torch.mean(x,axis=1)
        
        return x
    
class Classifier(nn.Module):
    
    def __init__(self,n_layer=4):
        super().__init__()
        self.userembedder = UserEmbedder(n_layer)
        self.fc = nn.Linear(768,64)
        self.dropout = nn.Dropout(p=0.1)
        self.fc2 = nn.Linear(64,7)
        
    def forward(self,x,src_mask=None):
        x = self.userembedder(x,key_padding_mask=src_mask)
        x = self.fc(x)
        x = self.dropout(x)
        x = self.fc2(x)
        
        return x

In [40]:
def evaluate(model,val_dataloader,device):
    
    print("----- Evaluating ------")
    
    model.eval()
   
    all_predictions = torch.tensor([])
    all_labels = torch.tensor([])
   
    model = model.to(device)
    
    with torch.no_grad(): ## Disable gradient
         for inputs, labels,attn_mask in tqdm(val_dataloader):

            #inputs,labels = batch
            #print(inputs,attn_mask)
            inputs = inputs.to(device)
            labels = labels.to(device)
            attn_mask = attn_mask.to(device)
            
            logits = model(inputs,attn_mask)
            
            
            
            max_index = torch.argmax(logits,axis=-1).cpu()
            pred = torch.nn.functional.one_hot(max_index,7)
            
            labels = labels.cpu()
            all_predictions = torch.cat((all_predictions,pred),axis=0)
            all_labels = torch.cat((all_labels,labels),axis=0)

    class_indices = torch.argmax(all_labels,axis=1)
    prediction_indices = torch.argmax(all_predictions,axis=1)
    f1 = f1_score(all_predictions,all_labels,average=None)
    acc = torch.sum(class_indices==prediction_indices)/len(prediction_indices)
    print(f1)
    print("ACC",acc)
    return {"pred":all_predictions,"labels":all_labels}

In [44]:
model = Classifier(6)
#model = Naive()
optimizer = torch.optim.Adam(model.parameters(),lr=0.0001)
loss_fn = nn.CrossEntropyLoss()
device = 'cuda:1'
loss_fn.to(device)
for i in tqdm(range(25)):
    model.train()
    for inp,label,attn in train_dataloader:
        
        optimizer.zero_grad()

        inp,label,attn = inp.to(device),label.to(device),attn.to(device)
        model = model.to(device)
        logits = model(inp,attn)
        loss = loss_fn(logits,label)
        loss.backward()
        
        optimizer.step()
    evaluate(model,test_dataloader,'cuda:1')
    #print(loss.item())

  0%|                                                                                                                                                                                | 0/25 [00:00<?, ?it/s]

----- Evaluating ------



  0%|                                                                                                                                                                                | 0/60 [00:00<?, ?it/s][A
  7%|███████████▏                                                                                                                                                            | 4/60 [00:00<00:01, 35.26it/s][A
 13%|██████████████████████▍                                                                                                                                                 | 8/60 [00:00<00:01, 36.53it/s][A
 20%|█████████████████████████████████▍                                                                                                                                     | 12/60 [00:00<00:01, 37.05it/s][A
 27%|████████████████████████████████████████████▌                                                                                                                     

[0.18546366 0.00664452 0.34608985 0.34709193 0.64135021 0.19649123
 0.65789474]
ACC tensor(0.3843)
----- Evaluating ------



  1 if key_padding_mask is not None else 0 if attn_mask is not None else None)

  7%|███████████▏                                                                                                                                                            | 4/60 [00:00<00:01, 34.29it/s][A
 13%|██████████████████████▍                                                                                                                                                 | 8/60 [00:00<00:01, 34.73it/s][A
 20%|█████████████████████████████████▍                                                                                                                                     | 12/60 [00:00<00:01, 35.28it/s][A
 27%|████████████████████████████████████████████▌                                                                                                                          | 16/60 [00:00<00:01, 35.77it/s][A
 33%|███████████████████████████████████████████████████████▋                          

[0.245      0.16066482 0.41476274 0.35860656 0.68224299 0.08333333
 0.77674419]
ACC tensor(0.4514)
----- Evaluating ------



  1 if key_padding_mask is not None else 0 if attn_mask is not None else None)

  7%|███████████▏                                                                                                                                                            | 4/60 [00:00<00:01, 33.80it/s][A
 13%|██████████████████████▍                                                                                                                                                 | 8/60 [00:00<00:01, 34.97it/s][A
 20%|█████████████████████████████████▍                                                                                                                                     | 12/60 [00:00<00:01, 35.53it/s][A
 27%|████████████████████████████████████████████▌                                                                                                                          | 16/60 [00:00<00:01, 35.75it/s][A
 33%|███████████████████████████████████████████████████████▋                          

[0.25587467 0.21852732 0.40787623 0.38145695 0.73315364 0.36871508
 0.75462392]
ACC tensor(0.4682)
----- Evaluating ------



  1 if key_padding_mask is not None else 0 if attn_mask is not None else None)

  7%|███████████▏                                                                                                                                                            | 4/60 [00:00<00:01, 34.44it/s][A
 13%|██████████████████████▍                                                                                                                                                 | 8/60 [00:00<00:01, 35.84it/s][A
 20%|█████████████████████████████████▍                                                                                                                                     | 12/60 [00:00<00:01, 36.30it/s][A
 27%|████████████████████████████████████████████▌                                                                                                                          | 16/60 [00:00<00:01, 36.62it/s][A
 33%|███████████████████████████████████████████████████████▋                          

[0.31       0.20967742 0.40357143 0.38993711 0.72727273 0.3950104
 0.77411765]
ACC tensor(0.4835)
----- Evaluating ------



  1 if key_padding_mask is not None else 0 if attn_mask is not None else None)

  7%|███████████▏                                                                                                                                                            | 4/60 [00:00<00:01, 34.77it/s][A
 13%|██████████████████████▍                                                                                                                                                 | 8/60 [00:00<00:01, 35.76it/s][A
 20%|█████████████████████████████████▍                                                                                                                                     | 12/60 [00:00<00:01, 36.28it/s][A
 27%|████████████████████████████████████████████▌                                                                                                                          | 16/60 [00:00<00:01, 36.62it/s][A
 33%|███████████████████████████████████████████████████████▋                          

[0.32894737 0.25728155 0.44       0.37644046 0.72580645 0.3782235
 0.70810811]
ACC tensor(0.4682)


 20%|█████████████████████████████████▌                                                                                                                                      | 5/25 [01:57<07:48, 23.43s/it]


KeyboardInterrupt: 