In [None]:
import ast
import numpy as np
import pandas as pd

import torch
# Import all libraries
import pandas as pd
import numpy as np
import re

# Huggingface transformers
import transformers
from transformers import BertModel,BertTokenizer,AdamW, get_linear_schedule_with_warmup

import torch
from torch import nn ,cuda
from torch.utils.data import DataLoader,Dataset,RandomSampler, SequentialSampler

import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint

from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, precision_recall_fscore_support

#handling html data
from bs4 import BeautifulSoup

import seaborn as sns
from pylab import rcParams
import matplotlib.pyplot as plt
from matplotlib import rc
%matplotlib inline

RANDOM_SEED = 42
np.random.seed(RANDOM_SEED)
torch.manual_seed(RANDOM_SEED)

device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")

In [None]:
df = pd.read_csv('final_dataset.csv')
df.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 191253 entries, 0 to 191252
Data columns (total 6 columns):
 #   Column          Non-Null Count   Dtype 
---  ------          --------------   ----- 
 0   CELEX           191253 non-null  object
 1   Text            191253 non-null  object
 2   labels          191253 non-null  object
 3   ohe_labels      191253 non-null  object
 4   gensim_summary  191253 non-null  object
 5   t5_summary      191253 non-null  object
dtypes: object(6)
memory usage: 8.8+ MB


In [None]:
df.head()

Unnamed: 0,CELEX,Text,labels,ohe_labels,gensim_summary,t5_summary
0,21980D1231(03),21980D1231(03) Decision No 3/80 of the EEC-Ice...,Greece|agreement (EU)|accession to the Europea...,"[1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...",21980D1231(03) Decision No 3/80 of the EEC-Ice...,Article 23 (1) shall be amended by the additio...
1,21986A1115(03),15.11.1986 EN Official Journal of the European...,Portugal|protocol to an agreement|accession to...,"[0, 0, 1, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, ...",15.11.1986 EN Official Journal of the European...,15.11.1986 EN Official Journal of the European...
2,21987A0720(02),20.7.1987 EN Official Journal of the European ...,protocol to an agreement|revision of an agreement,"[0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, ...",20.7.1987 EN Official Journal of the European ...,20.7.1987 EN Official Journal of the European ...
3,21987D0411(04),21987D0411(04) Decision No 3/86 of the EEC-Swe...,originating product,"[0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...",0020 - 0024 DECISION N° 3/86 OF THE EEC-SWEDEN...,referred to in paragraph 2: 'The importer's de...
4,21987D0411(05),21987D0411(05) Decision No 3/86 of the EEC-Swi...,originating product,"[0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...",0026 - 0030 DECISION N° 3/86 OF THE EEC-AUSTRI...,referred to in paragraph 2: 'the exporter's de...


In [None]:
X = df['Text'].values
X_gensim = df['gensim_summary'].values
X_t5 = df['t5_summary'].values

y = np.asarray([[int(val) for val in ohe[1:-1].split(',')] for ohe in df['ohe_labels'].values])

In [None]:
from sklearn.model_selection import train_test_split
# First Split for Train and Test
x_train,x_test, x_gensim_train, x_gensim_test, x_t5_train,x_t5_test, y_train,y_test = train_test_split(X, X_gensim, X_t5, y, test_size=0.2, random_state=101,shuffle=True)
# Next split Train in to training and validation
x_tr, x_val, x_gensim_tr, x_gensim_val, x_t5_tr, x_t5_val, y_tr, y_val = train_test_split(x_train, x_gensim_train, x_t5_train, y_train, test_size=0.2, random_state=101,shuffle=True)

In [None]:
x_gensim_tr.shape ,x_gensim_val.shape, x_gensim_test.shape

In [None]:
y_tr.shape, y_val.shape, y_test.shape

((122401, 91), (30601, 91), (38251, 91))

In [None]:
class QTagDataset (Dataset):
    def __init__(self,quest,tags, tokenizer, max_len):
        self.tokenizer = tokenizer
        self.text = quest
        self.labels = tags
        self.max_len = max_len
        
    def __len__(self):
        return len(self.text)
    
    def __getitem__(self, item_idx):
        text = self.text[item_idx]
        inputs = self.tokenizer.encode_plus(
            text,
            None,
            add_special_tokens=True, # Add [CLS] [SEP]
            max_length= self.max_len,
            padding = 'max_length',
            return_token_type_ids= False,
            return_attention_mask= True, # Differentiates padded vs normal token
            truncation=True, # Truncate data beyond max length
            return_tensors = 'pt' # PyTorch Tensor format
          )
        
        input_ids = inputs['input_ids'].flatten()
        attn_mask = inputs['attention_mask'].flatten()
        #token_type_ids = inputs["token_type_ids"]
        
        return {
            'input_ids': input_ids ,
            'attention_mask': attn_mask,
            'label': torch.tensor(self.labels[item_idx], dtype=torch.float)
            
        } 

In [None]:
class QTagDataModule (pl.LightningDataModule):
    
    def __init__(self,x_tr,y_tr,x_val,y_val,x_test,y_test,tokenizer,batch_size=16,max_token_len=200):
        super().__init__()
        self.tr_text = x_tr
        self.tr_label = y_tr
        self.val_text = x_val
        self.val_label = y_val
        self.test_text = x_test
        self.test_label = y_test
        self.tokenizer = tokenizer
        self.batch_size = batch_size
        self.max_token_len = max_token_len

    def setup(self):
        self.train_dataset = QTagDataset(quest=self.tr_text, tags=self.tr_label, tokenizer=self.tokenizer,max_len = self.max_token_len)
        self.val_dataset  = QTagDataset(quest=self.val_text,tags=self.val_label,tokenizer=self.tokenizer,max_len = self.max_token_len)
        self.test_dataset  = QTagDataset(quest=self.test_text,tags=self.test_label,tokenizer=self.tokenizer,max_len = self.max_token_len)
        
        
    def train_dataloader(self):
        return DataLoader (self.train_dataset,batch_size = self.batch_size,shuffle = True , num_workers=32)

    def val_dataloader(self):
        return DataLoader (self.val_dataset,batch_size= 32)

    def test_dataloader(self):
        return DataLoader (self.test_dataset,batch_size= 32)

In [None]:
# Initialize the Bert tokenizer
BERT_MODEL_NAME = "bert-base-cased" # we will use the BERT base model(the smaller one)
Bert_tokenizer = BertTokenizer.from_pretrained(BERT_MODEL_NAME)

In [None]:
# Initialize the parameters that will be use for training
N_EPOCHS = 12
BATCH_SIZE = 8
MAX_LEN = 512
LR = 2e-05

In [None]:
# Instantiate and set up the data_module
QTdata_module = QTagDataModule(x_gensim_tr,y_tr,x_gensim_val,y_val,x_gensim_test,y_test,Bert_tokenizer,BATCH_SIZE, MAX_LEN)
QTdata_module.setup()

In [None]:
class QTagClassifier(pl.LightningModule):
    # Set up the classifier
    def __init__(self, n_classes=91, steps_per_epoch=None, n_epochs=3, lr=2e-5 ):
        super().__init__()

        self.bert = BertModel.from_pretrained(BERT_MODEL_NAME, return_dict=True)
        self.classifier = nn.Linear(self.bert.config.hidden_size, n_classes) # outputs = number of labels
        self.steps_per_epoch = steps_per_epoch
        self.n_epochs = n_epochs
        self.lr = lr
        self.criterion = nn.BCEWithLogitsLoss()
#         self.log = 
        
    def forward(self,input_ids, attn_mask):
        output = self.bert(input_ids = input_ids ,attention_mask = attn_mask)
        output = self.classifier(output.pooler_output)
                
        return output
    
    
    def training_step(self,batch,batch_idx):
        input_ids = batch['input_ids']
        attention_mask = batch['attention_mask']
        labels = batch['label']
        
        outputs = self(input_ids,attention_mask)
        loss = self.criterion(outputs,labels)
#         self.log('train_loss',loss , prog_bar=True,logger=True)
        
        return {"loss" :loss, "predictions":outputs, "labels": labels }


    def validation_step(self,batch,batch_idx):
        input_ids = batch['input_ids']
        attention_mask = batch['attention_mask']
        labels = batch['label']
        
        outputs = self(input_ids,attention_mask)
        loss = self.criterion(outputs,labels)
#         self.log('val_loss',loss , prog_bar=True,logger=True)
        
        return loss

    def test_step(self,batch,batch_idx):
        input_ids = batch['input_ids']
        attention_mask = batch['attention_mask']
        labels = batch['label']
        
        outputs = self(input_ids,attention_mask)
        loss = self.criterion(outputs,labels)
#         self.log('test_loss',loss , prog_bar=True,logger=True)
        
        return loss
    
    
    def configure_optimizers(self):
        optimizer = AdamW(self.parameters() , lr=self.lr)
        warmup_steps = self.steps_per_epoch//3
        total_steps = self.steps_per_epoch * self.n_epochs - warmup_steps

        scheduler = get_linear_schedule_with_warmup(optimizer,warmup_steps,total_steps)

        return [optimizer], [scheduler]

In [None]:
# Instantiate the classifier model
steps_per_epoch = len(x_tr)//BATCH_SIZE
model = QTagClassifier(n_classes=91, steps_per_epoch=steps_per_epoch,n_epochs=N_EPOCHS,lr=LR)

Some weights of the model checkpoint at bert-base-cased were not used when initializing BertModel: ['cls.predictions.transform.dense.weight', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.bias', 'cls.predictions.decoder.weight', 'cls.predictions.bias', 'cls.predictions.transform.dense.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [None]:
#Initialize Pytorch Lightning callback for Model checkpointing

# saves a file like: input/QTag-epoch=02-val_loss=0.32.ckpt
checkpoint_callback = ModelCheckpoint(
    monitor='val_loss', # monitored quantity
    filepath="results/QTag-gensim-{epoch:02d}-{val_loss:.2f}",
    save_top_k=3, #  save the top 3 models
    mode='min', # mode of the monitored quantity  for optimization
)

In [None]:
# Instantiate the Model Trainer
trainer = pl.Trainer(max_epochs = N_EPOCHS , gpus = [1], checkpoint_callback=checkpoint_callback, progress_bar_refresh_rate = 30)


GPU available: True, used: True
TPU available: False, using: 0 TPU cores
CUDA_VISIBLE_DEVICES: [1]


In [None]:
# Train the Classifier Model
trainer.fit(model, QTdata_module)


  | Name       | Type              | Params
-------------------------------------------------
0 | bert       | BertModel         | 108 M 
1 | classifier | Linear            | 69 K  
2 | criterion  | BCEWithLogitsLoss | 0     


Validation sanity check: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Saving latest checkpoint..


1

In [None]:
# Evaluate the model performance on the test dataset
trainer.test(model,datamodule=QTdata_module)



Testing: 0it [00:00, ?it/s]

--------------------------------------------------------------------------------


1

In [None]:
# Visualize the logs using tensorboard.
%load_ext tensorboard
%tensorboard --logdir lightning_logs/

# Evaluate Model Performance on Test Set

In [1]:
# Retreive the checkpoint path for best model
model_path = checkpoint_callback.best_model_path

NameError: ignored

In [None]:
len(y_test), len(x_gensim_test)

(38251, 38251)

In [None]:
# Size of Test set
print(f'Number of Text = {len(x_gensim_test)}')

Number of Text = 38251


## setup test dataset for BERT

In [None]:
from torch.utils.data import TensorDataset

# Tokenize all questions in x_test
input_ids = []
attention_masks = []


for quest in x_gensim_test:
    encoded_quest =  Bert_tokenizer.encode_plus(
                    quest,
                    None,
                    add_special_tokens=True,
                    max_length= MAX_LEN,
                    padding = 'max_length',
                    return_token_type_ids= False,
                    return_attention_mask= True,
                    truncation=True,
                    return_tensors = 'pt'      
    )
    
    # Add the input_ids from encoded question to the list.    
    input_ids.append(encoded_quest['input_ids'])
    # Add its attention mask 
    attention_masks.append(encoded_quest['attention_mask'])
    
# Now convert the lists into tensors.
input_ids = torch.cat(input_ids, dim=0)
attention_masks = torch.cat(attention_masks, dim=0)
labels = torch.tensor(y_test)

# Set the batch size.  
TEST_BATCH_SIZE = 64  

# Create the DataLoader.
pred_data = TensorDataset(input_ids, attention_masks, labels)
pred_sampler = SequentialSampler(pred_data)
pred_dataloader = DataLoader(pred_data, sampler=pred_sampler, batch_size=TEST_BATCH_SIZE)
    

In [None]:
pred_data[0]

(tensor([  101,  1853,   119,   128,   119,  1349,   142,  2249,  9018,  3603,
          1104,  1103,  1735,  1913,   149, 21764,   120,  1275, 18732, 25290,
          6258, 13882, 11414,   146, 12347, 17516, 14424, 15681, 15740,   155,
         17020,  2591, 10783, 21669, 11414,   113,  7270,   114,  1302,  5692,
          1559,   120,  1349,  1104,  1743,  1351,  1349,  1113,  1103,  5867,
         10148,  4019,  1106,  1129,  4275,  1107,  2593,  1106,  1103,  1248,
          7597,  8727,  1106,  8886,  1439,  1103,  8886,  1158,  7791,  1533,
          1118,   146, 26318,  1880,  1158, 22575,   113,  7270,   114,  1302,
          5519,  1527,   120,  1349,  7462,  7270, 21564,  2101, 12420,  2249,
         18732, 25290,  6258, 13882, 11414,   117,  5823,  7328,  1106,  1103,
          6599,  1113,  1103, 16068,  5796,  1158,  1104,  1103,  1735,  1913,
           117,  5823,  7328,  1106,  1761, 22575,   113, 16028,   114,  1302,
         13414,  1527,   120,  1384,  1104,  1659,  

In [None]:
len(pred_dataloader.dataset)

38251

## Prediction on test set

In [None]:
flat_pred_outs = 0
flat_true_labels = 0

In [None]:
# Put model in evaluation mode
model = model.to(device) # moving model to cuda
model.eval()

# Tracking variables 
pred_outs, true_labels = [], []
#i=0
# Predict 
for batch in pred_dataloader:
    # Add batch to GPU
    batch = tuple(t.to(device) for t in batch)
  
    # Unpack the inputs from our dataloader
    b_input_ids, b_attn_mask, b_labels = batch
 
    with torch.no_grad():
        # Forward pass, calculate logit predictions
        pred_out = model(b_input_ids,b_attn_mask)
        pred_out = torch.sigmoid(pred_out)
        # Move predicted output and labels to CPU
        pred_out = pred_out.detach().cpu().numpy()
        label_ids = b_labels.to('cpu').numpy()
        #i+=1
        # Store predictions and true labels
        #print(i)
        #print(outputs)
        #print(logits)
        #print(label_ids)
    pred_outs.append(pred_out)
    true_labels.append(label_ids)

In [None]:
pred_outs[0][0]

array([0.04393886, 0.04026005, 0.03702657, 0.06509363, 0.03968716,
       0.02838152, 0.05158657, 0.0406674 , 0.04223459, 0.0387723 ,
       0.0439634 , 0.03686194, 0.05175484, 0.02595243, 0.03147109,
       0.07856017, 0.04027737, 0.04643226, 0.02796185, 0.04288817,
       0.03926515, 0.0493609 , 0.03507802, 0.05418249, 0.06684761,
       0.04126413, 0.05177863, 0.02937672, 0.02752749, 0.05514112,
       0.02625699, 0.05497538, 0.02959226, 0.04898583, 0.02912938,
       0.03330458, 0.03126326, 0.04742906, 0.04016345, 0.06085854,
       0.03757559, 0.02863928, 0.0272129 , 0.03694598, 0.04670656,
       0.05081735, 0.04376685, 0.03831787, 0.04239337, 0.05047482,
       0.02928953, 0.0396766 , 0.06822337, 0.04596595, 0.07472587,
       0.03266276, 0.04203046, 0.03458502, 0.06684907, 0.03215622,
       0.0481158 , 0.05603847, 0.03528995, 0.05482062, 0.03447564,
       0.04375356, 0.02653649, 0.03402059, 0.03573639, 0.06251746,
       0.03735632, 0.05038344, 0.04174133, 0.0638653 , 0.03893

In [None]:
# Combine the results across all batches. 
flat_pred_outs = np.concatenate(pred_outs, axis=0)

# Combine the correct labels for each batch into a single list.
flat_true_labels = np.concatenate(true_labels, axis=0)

In [None]:
flat_pred_outs.shape , flat_true_labels.shape

((38251, 91), (38251, 91))

## Predictions of Tags in Test set

> The predictions are in terms of logits (probabilities for each of the 16 tags). Hence we need to have a threshold value to convert these probabilities to 0 or 1.

> Let's specify a set of candidate threshold values. We will select the threshold value that performs the best for the test set.

In [None]:
#define candidate threshold values
threshold  = np.linspace(0.01, 0.11, 10)
threshold

array([0.01      , 0.02111111, 0.03222222, 0.04333333, 0.05444444,
       0.06555556, 0.07666667, 0.08777778, 0.09888889, 0.11      ])

> Let's define a function that takes a threshold value and uses it to convert probabilities into 1 or 0.

In [None]:
# convert probabilities into 0 or 1 based on a threshold value
def classify(pred_prob,thresh):
    y_pred = []

    for tag_label_row in pred_prob:
        temp=[]
        for tag_label in tag_label_row:
            if tag_label >= thresh:
                temp.append(1) # Infer tag value as 1 (present)
            else:
                temp.append(0) # Infer tag value as 0 (absent)
        y_pred.append(temp)

    return y_pred

In [None]:
flat_pred_outs[3]

array([0.04465479, 0.04000388, 0.03793235, 0.064553  , 0.04038289,
       0.02868011, 0.05141482, 0.04039403, 0.04085653, 0.03885972,
       0.04467247, 0.03720953, 0.05279379, 0.02677394, 0.03196805,
       0.07943403, 0.04075136, 0.0465932 , 0.02802228, 0.04379119,
       0.03988887, 0.04814532, 0.03448923, 0.05348637, 0.06770615,
       0.04251643, 0.05299169, 0.0310928 , 0.02766096, 0.05564284,
       0.02683596, 0.05448954, 0.03006033, 0.04838379, 0.02943339,
       0.03272092, 0.03122783, 0.04747922, 0.04034546, 0.06104731,
       0.03917881, 0.0277924 , 0.02774214, 0.03724122, 0.04618394,
       0.04954909, 0.04315152, 0.03886757, 0.04151737, 0.04929136,
       0.02934361, 0.03959372, 0.06875823, 0.04386379, 0.0723106 ,
       0.03326748, 0.04428256, 0.03480516, 0.06830633, 0.03237043,
       0.04901989, 0.05346571, 0.03514029, 0.05447625, 0.03535723,
       0.04356833, 0.02636186, 0.03323644, 0.03564078, 0.06113088,
       0.03721401, 0.04981613, 0.0418331 , 0.06664433, 0.03810

In [None]:
flat_true_labels[3]

array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 1,
       0, 0, 0])

In [None]:
flat_pred_outs[10]

array([0.0447965 , 0.04006191, 0.03797207, 0.06443455, 0.04063178,
       0.02868171, 0.05162624, 0.0405179 , 0.04112419, 0.03896799,
       0.04485374, 0.03778259, 0.05278785, 0.02686218, 0.03209111,
       0.07981121, 0.04031513, 0.04681759, 0.02811153, 0.04394386,
       0.04026797, 0.04798578, 0.03466038, 0.05378013, 0.06769781,
       0.04268819, 0.05293808, 0.03133454, 0.02777813, 0.05590944,
       0.02686482, 0.05400006, 0.03000082, 0.04816921, 0.02965449,
       0.03264096, 0.03143499, 0.04771715, 0.04037525, 0.06121652,
       0.03920124, 0.0277793 , 0.02796538, 0.03731326, 0.04606533,
       0.04943749, 0.04330076, 0.03925463, 0.0412901 , 0.04893321,
       0.0293907 , 0.03949533, 0.0687308 , 0.04356764, 0.07196701,
       0.03333909, 0.04489695, 0.03494585, 0.0687817 , 0.03241259,
       0.04948201, 0.05340688, 0.03508839, 0.05515257, 0.03553026,
       0.04379005, 0.02629385, 0.0331658 , 0.03554313, 0.06083254,
       0.03721584, 0.04952275, 0.04167248, 0.0675009 , 0.03808

In [None]:
flat_true_labels[10]

array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 1,
       0, 0, 0])

In [None]:
from sklearn import metrics
scores=[] # Store the list of f1 scores for prediction on each threshold

#convert labels to 1D array
y_true = flat_true_labels.ravel() 

for thresh in threshold:
    
    #classes for each threshold
    pred_bin_label = classify(flat_pred_outs,thresh) 

    #convert to 1D array
    y_pred = np.array(pred_bin_label).ravel()

    scores.append(metrics.f1_score(y_true,y_pred))

In [None]:
# find the optimal threshold
opt_thresh = threshold[scores.index(max(scores))]
print(f'Optimal Threshold Value = {opt_thresh}')

Optimal Threshold Value = 0.06555555555555556


## Performance Score Evaluation

In [None]:
#predictions for optimal threshold
y_pred_labels = classify(flat_pred_outs,opt_thresh)
y_pred = np.array(y_pred_labels).ravel() # Flatten

In [None]:
print(metrics.classification_report(y_true,y_pred))

              precision    recall  f1-score   support

           0       0.98      0.93      0.95   3398632
           1       0.05      0.15      0.08     82209

    accuracy                           0.91   3480841
   macro avg       0.51      0.54      0.52   3480841
weighted avg       0.96      0.91      0.93   3480841



In [None]:
y = np.asarray([np.asarray(labels.split('|'), dtype='object') for labels in df['labels']], dtype='object')
y_ohe = np.asarray([[int(x) for x in labels[1:-1].split(',')] for labels in df['ohe_labels']])
labellst = []
for i in range(91):    
    llset = set(y[np.where(y_ohe[:, i]==1)][0])
    for ll in y[np.where(y_ohe[:, i]==1)]:
        llset = llset.intersection(ll)
    labellst.append(list(llset)[0])
labellst = np.asarray(labellst)
labellst

array(['Greece', 'agreement (EU)', 'accession to the European Union',
       'originating product', 'Portugal', 'protocol to an agreement',
       'Spain', 'revision of an agreement', 'EU programme',
       'third country', 'European Economic Area', 'fishing area',
       'sea fish', 'fishing rights', 'catch quota',
       'economic concentration', 'labelling', 'approximation of laws',
       'consumer information', 'public health',
       'environmental protection', 'export refund', 'action programme',
       'technical standard', 'international sanctions', 'human rights',
       'import', 'import licence', 'marketing', 'veterinary inspection',
       'agricultural product', 'exchange of information',
       'marketing standard', 'foodstuff', 'consumer protection',
       'cooperation policy', 'air transport', 'EU aid', 'EU financing',
       'EU Member State', 'provision of services',
       'disclosure of information', 'chemical product', 'indemnification',
       'health control', 

In [None]:
len(labellst)

91

In [None]:
def inverse_labels(arr):
    inv_labels = []
    for labels in arr:
        inv_labels.append(labellst[np.where(labels)])        
    return np.asarray(inv_labels, dtype='object')

In [None]:
y_pred = inverse_labels(np.array(y_pred_labels))
y_act = inverse_labels(flat_true_labels)

In [None]:
y_pred[0]

array(['economic concentration', 'international sanctions', 'State aid',
       'Italy', 'euro', 'merger control'], dtype='<U42')

In [None]:
df_pred = pd.DataFrame({'Body':x_test,'Actual Tags':y_act,'Predicted Tags':y_pred})

In [None]:
df_pred

Unnamed: 0,Body,Actual Tags,Predicted Tags
0,29.7.2011 EN Official Journal of the European ...,"[import licence, import (EU)]","[economic concentration, international sanctio..."
1,31.7.2010 EN Official Journal of the European ...,"[State aid, Italy, control of State aid, Europ...","[economic concentration, international sanctio..."
2,17.4.2014 EN Official Journal of the European ...,"[import (EU), common organisation of markets, ...","[originating product, economic concentration, ..."
3,30.3.2015 EN Official Journal of the European ...,"[EU trade mark, trademark law, registered trad...","[economic concentration, international sanctio..."
4,16.11.2017 EN Official Journal of the European...,"[revision of an agreement, European Economic A...","[economic concentration, international sanctio..."
...,...,...,...
38246,27.3.2004 EN Official Journal of the European ...,[EU financing],"[economic concentration, international sanctio..."
38247,32001R2522 Commission Regulation (EC) No 2522/...,"[third country, export refund, award of contract]","[economic concentration, international sanctio..."
38248,5.2.2018 EN Official Journal of the European U...,"[European trademark, trademark law, registered...","[economic concentration, international sanctio..."
38249,"EUROPEAN COMMISSION Brussels,14.6.2018 COM(201...",[EU aid],"[economic concentration, international sanctio..."


In [None]:
df_pred.to_csv('gensim_pred.csv', index = False)

In [None]:
x_gensim_test_unique, x_gensim_test_idx = np.unique(x_gensim_test, return_index = True)

In [None]:
Ids = df.iloc[df[df['gensim_summary'].isin(x_gensim_test_unique)]['gensim_summary'].drop_duplicates().index]['CELEX']

# DataFrame for Top-5 predictions

In [None]:
predictions = pd.DataFrame(columns = labellst, data = flat_pred_outs[x_gensim_test_idx], index = Ids)
predictions.reset_index(inplace = True)
predictions.rename(columns = {'CELEX': 'id'}, inplace = True)
predictions

Unnamed: 0,id,Greece,agreement (EU),accession to the European Union,originating product,Portugal,protocol to an agreement,Spain,revision of an agreement,EU programme,...,financial aid,investment company,EU trade mark,equal treatment,European trademark,trademark law,registered trademark,interpretation of the law,action for failure to fulfil an obligation,action for annulment of an EC decision
0,21987D1231(03),0.043448,0.042596,0.037585,0.067159,0.039057,0.028681,0.051261,0.040938,0.043500,...,0.022925,0.035487,0.043445,0.033151,0.041113,0.045127,0.061686,0.045624,0.050990,0.040495
1,21991D1112(07),0.045302,0.040721,0.037804,0.065650,0.039498,0.028315,0.051659,0.040118,0.041728,...,0.023505,0.034755,0.040871,0.033427,0.041346,0.048056,0.062494,0.045464,0.050888,0.041950
2,21994D1231(13),0.044837,0.040459,0.037286,0.064526,0.039741,0.028375,0.051865,0.040463,0.041820,...,0.023065,0.035219,0.040636,0.033275,0.041277,0.046390,0.062082,0.044831,0.050361,0.041596
3,21997D0327(03),0.044199,0.040051,0.036876,0.064916,0.039842,0.028434,0.051921,0.040423,0.042400,...,0.022680,0.035545,0.040601,0.033118,0.041387,0.045485,0.061071,0.044831,0.050208,0.040697
4,21997D0710(18),0.044471,0.040117,0.036844,0.064793,0.039791,0.028467,0.051989,0.040480,0.042302,...,0.022796,0.035529,0.040120,0.033305,0.041524,0.045841,0.061181,0.044847,0.049863,0.041090
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
38209,E2010C0159,0.044110,0.040014,0.037224,0.064809,0.039672,0.028617,0.051738,0.040605,0.042147,...,0.022806,0.035461,0.040442,0.033062,0.041371,0.046028,0.061505,0.044847,0.050113,0.041010
38210,E2010C1209(02),0.044147,0.040336,0.037258,0.065080,0.039489,0.028420,0.051726,0.040681,0.042330,...,0.022849,0.035246,0.040858,0.032780,0.041226,0.045815,0.061703,0.044937,0.050382,0.041049
38211,E2012C0112(03),0.044122,0.040052,0.036995,0.065162,0.039691,0.028245,0.051657,0.040413,0.042082,...,0.022759,0.035381,0.040319,0.033276,0.041456,0.045933,0.061319,0.044790,0.050057,0.040948
38212,E2013C0214(02),0.044144,0.040023,0.037149,0.065022,0.039781,0.028664,0.051816,0.040661,0.042316,...,0.022832,0.035459,0.040303,0.033228,0.041474,0.046257,0.061637,0.044754,0.049917,0.041127


In [None]:
predictions.set_index('id', inplace = True)
predictions.T

id,21987D1231(03),21991D1112(07),21994D1231(13),21997D0327(03),21997D0710(18),21999A0716(01),22000P0301(14),22003D0041,22007D0137,22010A0622(03),...,C2017/082/01,C2017/301/03,C2019/207/02,C2019/373/03,E2004C0305R(01),E2010C0159,E2010C1209(02),E2012C0112(03),E2013C0214(02),E2019P0003
Greece,0.043448,0.045302,0.044837,0.044199,0.044471,0.043862,0.044017,0.044176,0.044664,0.044599,...,0.044531,0.043913,0.044896,0.044256,0.044256,0.044110,0.044147,0.044122,0.044144,0.043962
agreement (EU),0.042596,0.040721,0.040459,0.040051,0.040117,0.040588,0.040324,0.040472,0.040027,0.040038,...,0.041244,0.040017,0.040298,0.039936,0.040169,0.040014,0.040336,0.040052,0.040023,0.040128
accession to the European Union,0.037585,0.037804,0.037286,0.036876,0.036844,0.037295,0.036968,0.037310,0.037237,0.037371,...,0.037725,0.036907,0.037455,0.036956,0.037064,0.037224,0.037258,0.036995,0.037149,0.037074
originating product,0.067159,0.065650,0.064526,0.064916,0.064793,0.065816,0.065291,0.065213,0.064929,0.065113,...,0.066616,0.064955,0.064867,0.064736,0.064775,0.064809,0.065080,0.065162,0.065022,0.065140
Portugal,0.039057,0.039498,0.039741,0.039842,0.039791,0.039772,0.039538,0.039611,0.039896,0.039758,...,0.039408,0.040006,0.039904,0.039925,0.039752,0.039672,0.039489,0.039691,0.039781,0.039644
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
trademark law,0.045127,0.048056,0.046390,0.045485,0.045841,0.046028,0.045661,0.046033,0.047156,0.047059,...,0.047147,0.045823,0.047160,0.046004,0.045861,0.046028,0.045815,0.045933,0.046257,0.045952
registered trademark,0.061686,0.062494,0.062082,0.061071,0.061181,0.061482,0.061324,0.061334,0.061907,0.061986,...,0.062237,0.061441,0.062161,0.061314,0.061446,0.061505,0.061703,0.061319,0.061637,0.061408
interpretation of the law,0.045624,0.045464,0.044831,0.044831,0.044847,0.044957,0.044879,0.045135,0.045077,0.045080,...,0.045455,0.044970,0.044732,0.044686,0.044876,0.044847,0.044937,0.044790,0.044754,0.044953
action for failure to fulfil an obligation,0.050990,0.050888,0.050361,0.050208,0.049863,0.050629,0.050193,0.050324,0.050017,0.050351,...,0.051129,0.050047,0.050250,0.050042,0.049918,0.050113,0.050382,0.050057,0.049917,0.050205


In [None]:
pred_transform = predictions.T.copy()
pred_transform

id,21987D1231(03),21991D1112(07),21994D1231(13),21997D0327(03),21997D0710(18),21999A0716(01),22000P0301(14),22003D0041,22007D0137,22010A0622(03),...,C2017/082/01,C2017/301/03,C2019/207/02,C2019/373/03,E2004C0305R(01),E2010C0159,E2010C1209(02),E2012C0112(03),E2013C0214(02),E2019P0003
Greece,0.043448,0.045302,0.044837,0.044199,0.044471,0.043862,0.044017,0.044176,0.044664,0.044599,...,0.044531,0.043913,0.044896,0.044256,0.044256,0.044110,0.044147,0.044122,0.044144,0.043962
agreement (EU),0.042596,0.040721,0.040459,0.040051,0.040117,0.040588,0.040324,0.040472,0.040027,0.040038,...,0.041244,0.040017,0.040298,0.039936,0.040169,0.040014,0.040336,0.040052,0.040023,0.040128
accession to the European Union,0.037585,0.037804,0.037286,0.036876,0.036844,0.037295,0.036968,0.037310,0.037237,0.037371,...,0.037725,0.036907,0.037455,0.036956,0.037064,0.037224,0.037258,0.036995,0.037149,0.037074
originating product,0.067159,0.065650,0.064526,0.064916,0.064793,0.065816,0.065291,0.065213,0.064929,0.065113,...,0.066616,0.064955,0.064867,0.064736,0.064775,0.064809,0.065080,0.065162,0.065022,0.065140
Portugal,0.039057,0.039498,0.039741,0.039842,0.039791,0.039772,0.039538,0.039611,0.039896,0.039758,...,0.039408,0.040006,0.039904,0.039925,0.039752,0.039672,0.039489,0.039691,0.039781,0.039644
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
trademark law,0.045127,0.048056,0.046390,0.045485,0.045841,0.046028,0.045661,0.046033,0.047156,0.047059,...,0.047147,0.045823,0.047160,0.046004,0.045861,0.046028,0.045815,0.045933,0.046257,0.045952
registered trademark,0.061686,0.062494,0.062082,0.061071,0.061181,0.061482,0.061324,0.061334,0.061907,0.061986,...,0.062237,0.061441,0.062161,0.061314,0.061446,0.061505,0.061703,0.061319,0.061637,0.061408
interpretation of the law,0.045624,0.045464,0.044831,0.044831,0.044847,0.044957,0.044879,0.045135,0.045077,0.045080,...,0.045455,0.044970,0.044732,0.044686,0.044876,0.044847,0.044937,0.044790,0.044754,0.044953
action for failure to fulfil an obligation,0.050990,0.050888,0.050361,0.050208,0.049863,0.050629,0.050193,0.050324,0.050017,0.050351,...,0.051129,0.050047,0.050250,0.050042,0.049918,0.050113,0.050382,0.050057,0.049917,0.050205


In [None]:
top_5 = pd.DataFrame(index = predictions.index, columns = ['Top_1', 'Top_2', 'Top_3', 'Top_4', 'Top_5', 'labels'])

for col in pred_transform.columns:
    labels = df[df['CELEX'] == col]['labels']
    top_5.loc[col] = list(predictions.T[col].nlargest(5).items()) + [tuple(labels)]
    
top_5

  arr_value = np.asarray(value)


Unnamed: 0_level_0,Top_1,Top_2,Top_3,Top_4,Top_5,labels
id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1
21987D1231(03),"(economic concentration, 0.07823850214481354)","(Italy, 0.07719823718070984)","(merger control, 0.07299649715423584)","(State aid, 0.0693913921713829)","(international sanctions, 0.06766422092914581)","(originating product,)"
21991D1112(07),"(economic concentration, 0.07841677963733673)","(merger control, 0.07454001158475876)","(Italy, 0.07278908789157867)","(euro, 0.0697210505604744)","(State aid, 0.06909801810979843)","(agreement (EU)|originating product,)"
21994D1231(13),"(economic concentration, 0.07873343676328659)","(Italy, 0.07383434474468231)","(merger control, 0.07362982630729675)","(euro, 0.06873254477977753)","(State aid, 0.06867988407611847)","(European Economic Area,)"
21997D0327(03),"(economic concentration, 0.07898244261741638)","(Italy, 0.07521649450063705)","(merger control, 0.07372426986694336)","(State aid, 0.06824901700019836)","(euro, 0.06729411333799362)","(European Economic Area|agreement (EU),)"
21997D0710(18),"(economic concentration, 0.07853732258081436)","(Italy, 0.07401464879512787)","(merger control, 0.07347004860639572)","(euro, 0.06809794902801514)","(State aid, 0.06798099726438522)",(European Economic Area|public health|environm...
...,...,...,...,...,...,...
E2010C0159,"(economic concentration, 0.07913943380117416)","(Italy, 0.07442620396614075)","(merger control, 0.07432220131158829)","(State aid, 0.06783153861761093)","(euro, 0.06771905720233917)","(health control,)"
E2010C1209(02),"(economic concentration, 0.07917667925357819)","(Italy, 0.07446056604385376)","(merger control, 0.07390908151865005)","(State aid, 0.06804364174604416)","(euro, 0.06774389743804932)","(financial aid|control of State aid|State aid,)"
E2012C0112(03),"(economic concentration, 0.07870585471391678)","(Italy, 0.07395709306001663)","(merger control, 0.0738767459988594)","(State aid, 0.06848016381263733)","(euro, 0.06762175261974335)","(control of State aid|State aid,)"
E2013C0214(02),"(economic concentration, 0.07905878871679306)","(merger control, 0.07412014901638031)","(Italy, 0.07383929938077927)","(euro, 0.06826850026845932)","(State aid, 0.06786638498306274)","(State aid|control of State aid,)"


In [None]:
top_5.to_csv('gensim_top_5_pred.csv')