# Library Installation

In [1]:
# !pip install spacy
# !python -m spacy download en_core_web_sm
# !pip install datasets
# !pip install wordninja
# !pip install textblob
# !pip install nltk
# !pip install sentence-transformers
# import nltk
# nltk.download('punkt')
# nltk.download('wordnet')
# nltk.download('averaged_perceptron_tagger')
# !pip install swifter
# !pip install pyspellchecker

# Import Libraries + Load Models

In [2]:
import pandas as pd, numpy as np
import re
import swifter
import gensim.downloader as api
import spacy
import en_core_web_sm
import string
import wordninja
import torch
import torch.nn as nn
import warnings

from datasets import load_dataset
from spellchecker import SpellChecker
from nltk.stem import WordNetLemmatizer
from textblob import Word
from nltk.corpus import wordnet
from sklearn.svm import SVC
from sklearn.metrics import accuracy_score, classification_report
from torch.optim import AdamW
from transformers import BertTokenizer, BertModel
from transformers import DistilBertForSequenceClassification,DistilBertModel, DistilBertTokenizer
from torch.utils.data import DataLoader, TensorDataset
from sklearn.model_selection import train_test_split
from tqdm import tqdm
from sklearn.exceptions import UndefinedMetricWarning

# suppress UndefinedMetricWarning
warnings.filterwarnings("ignore", category=UndefinedMetricWarning)


word2vec_model = api.load("word2vec-google-news-300")
nlp = spacy.load('en_core_web_sm')

# Fetch Dataset

In [3]:
load_dataset("sem_eval_2010_task_8")

DatasetDict({
    train: Dataset({
        features: ['sentence', 'relation'],
        num_rows: 8000
    })
    test: Dataset({
        features: ['sentence', 'relation'],
        num_rows: 2717
    })
})

NOTE: This dataset doesnot contain any validation set, hence validation is splitting the training set

In [4]:
train_df = pd.DataFrame(load_dataset("sem_eval_2010_task_8", split = "train"))

In [5]:
test_df = pd.DataFrame(load_dataset("sem_eval_2010_task_8", split = "test"))

# Preprocess Data

Functions to extract features, improve data quality and more

### 1. Extract Entities

In [6]:
train_df.iloc[0]['sentence']

'The system as described above has its greatest application in an arrayed <e1>configuration</e1> of antenna <e2>elements</e2>.'

Sentence contains entities enclosed in '<>'

In [7]:
def extract_entities(sentence):
    try:
        e1 = re.search(r'<e1>(.*?)</e1>', sentence).group(1).lower().strip()
        e2 = re.search(r'<e2>(.*?)</e2>', sentence).group(1).lower().strip()
    except:
        # raise error if entities are not enclosed in '<>' in sentence
        raise ValueError('Sentence passed is not in correct format')
    return pd.Series([e1, e2], index=['e1', 'e2'])

### 2. Synset Check

In [8]:
spell_checker = SpellChecker(distance=1)
lemmatizer = WordNetLemmatizer()
count = 0

def synset_check(text,sentence):
    global count
    return_text = ''
    
    try:
        text = int(text[0])
        text = 'number'
    except:
        pass
   
    if wordnet.synsets(text): # 1. search text 
        return_text = text
    
    elif wordnet.synsets(text.capitalize()): # 2. search capitalised text 
        return_text = text.capitalize()
        
    elif wordnet.synsets(text.replace('-','')): # 3. remove hyphen and search
        return_text = text.replace('-','')
        
    elif wordnet.synsets(str(lemmatizer.lemmatize(text, pos='n'))): # 4. convert to singular and then search
        return_text = str(lemmatizer.lemmatize(text, pos='n'))
        
    elif wordnet.synsets(str(lemmatizer.lemmatize(text, pos='n')).capitalize()): # 5. convert to singular & capitalise and then search
        return_text = str(lemmatizer.lemmatize(text, pos='n')).capitalize()
        
    elif wordnet.synsets(str(Word(text).lemmatize())): # 6. convert past to present and search
        return_text = str(Word(text).lemmatize())
    
    elif wordnet.synsets(text.replace('-',' ').split()[-1]): # 7. if >1 words, search last word
        return_text = text.replace('-',' ').split()[-1]
        
    elif wordnet.synsets(text.replace('-',' ').split()[0]): # 8. if >1 words, search first word
        return_text = text.replace('-',' ').split()[0]
        
    # some custom defined rules
    # 1. 'er' cases
    elif str(Word(text).lemmatize())[-2:]=='er' and wordnet.synsets(str(Word(text).lemmatize())[:-2]): # 9. remove last 'er' of word and search
        return_text = str(Word(text).lemmatize())[:-2]
        
    # 2. 'ment' cases
    elif str(Word(text).lemmatize())[-4:]=='ment' and wordnet.synsets(str(Word(text).lemmatize())[:-4]): # 10. remove last 'ment' of word and search
        return_text = str(Word(text).lemmatize())[:-4]
    
    elif wordnet.synsets(str(spell_checker.correction(text))): # 11. correct spelling, if any, and search
        return_text = str(spell_checker.correction(text))
    
    elif wordninja.split(nlp(text)[0].lemma_)!=[] and wordnet.synsets(max(wordninja.split(nlp(text)[0].lemma_), key=len)): # 12. split words into segment and search the longest word
        return_text = max(wordninja.split(nlp(text)[0].lemma_), key=len)

    elif wordninja.split(nlp(text)[0].lemma_)!=[] and wordnet.synsets(str(wordninja.split(nlp(text)[0].lemma_)[-1])): # 13. split words into segment and search the last word
        return_text = str(wordninja.split(nlp(text)[0].lemma_)[-1])
    
    elif wordninja.split(nlp(text)[0].lemma_)!=[] and wordnet.synsets(str(wordninja.split(nlp(text)[0].lemma_)[0])): # 14. split words into segment and search the first word
        return_text = str(wordninja.split(nlp(text)[0].lemma_)[0])

    else:
        pass    
    
    if return_text =='':
        count = count+1
        return_text = text

    return_sentence = sentence.replace(text, return_text)
    match = re.search(r"<e1>(.*?)</e2>", return_sentence)
    if match:
        trimmed_sentence = (match.group(1))  # Extract the captured group (trip information)
    else:
        trimmed_sentence = None
    
    # remove entity tags '<>'
    return_sentence = re.sub(r"<[^>]+>", "",return_sentence)
    trimmed_sentence = re.sub(r"<[^>]+>", "",trimmed_sentence)
    # remove punctuations
    punctuation_set = ''.join(char for char in string.punctuation)
    return_sentence = return_sentence.translate(str.maketrans('', '', punctuation_set))
    trimmed_sentence = trimmed_sentence.translate(str.maketrans('', '', punctuation_set))
    # remove extra spaces
    return_sentence = re.sub(r"\s+", " ", return_sentence).strip()
    trimmed_sentence = re.sub(r"\s+", " ", trimmed_sentence).strip()
        
    
    return pd.Series([return_sentence, return_text, trimmed_sentence], index=['corrected_sentence', 'e', 'trimmed_sentence'])
    

### 3. Custom Entity Disambiguation 

In [9]:
def segment_synsets(word):
    words = word.replace('-',' ').split(' ')
    new_words = []
    
    for item in words:
        try:
            new_words = new_words + wordninja.split(nlp(item)[0].lemma_)
        except:
            pass
        
    for j in words:
        for i in range(1,len(j)):
            prefix = j[:i]
            suffix = j[i:]
            if wordnet.synsets(prefix) and wordnet.synsets(suffix):

                new_words = [prefix,suffix] + new_words
            elif wordnet.synsets(prefix) and len(prefix)>=5:
                new_words = [prefix] + new_words
            elif wordnet.synsets(suffix) and len(suffix)>=5:
                new_words = [suffix] + new_words

    new_words = [item for item in new_words if len(item) >= 3]
        
    synsets = []
    for item in new_words:
        item_synset = wordnet.synsets(item)
        synsets = synsets + item_synset
    return synsets

    

def filter_main_tokens(sentence):

    doc = nlp(sentence)
    # filtering based on POS tags
    filtered_tokens = [token.text for token in doc if token.pos_ in ("NOUN", "VERB", "ADJ")]

    return filtered_tokens
    

def disambiguate_entity_in_sentence(sentence, word, flag):
    
    # main tokens of the sentence
    sentence_tokens = filter_main_tokens(sentence.replace(word,''))
    
    # fetch synsets for the word
    synsets = wordnet.synsets(word.replace(' ','_'))
    
    # handle british vs american english
    if len(synsets)==0 and word.count('s')==1:
        synsets = wordnet.synsets(word.replace('s','z'))
    if len(synsets)==0 and word.count('z')==1:
        synsets = wordnet.synsets(word.replace('z','s'))
        
    # custom synset extraction by segmenting words
    if len(synsets)==0:
        synsets = segment_synsets(word)
    
    if len(synsets)==0:
        print ('Synset for "'+word+'" NOT FOUND')  
    
    # fetch the most relatable synset based on sentence
    scores = {}
    for synset in synsets:
        synset_tokens = filter_main_tokens(synset.definition())

        synset_embeddings = [word2vec_model[token] for token in synset_tokens if token in word2vec_model]
        
        if len(synset_embeddings)>0:
            avg_synset_embedding = np.mean(synset_embeddings, axis=0)
            # calculate similarity score based on cosine similarity between avg_synset_embedding and each token in sentence
            similarity_scores = [np.dot(avg_synset_embedding, word2vec_model[token])/(np.linalg.norm(avg_synset_embedding)*np.linalg.norm(word2vec_model[token]))
                                 for token in sentence_tokens if token in word2vec_model]
            scores[synset] = np.mean(similarity_scores)
        
        else:
            scores[synset] = 0
    
    # extract the highest score synset
    if len(synsets)>0:
        best_synset = max(scores, key=scores.get)
        
    else:
        best_synset = wordnet.synsets('unavailable')[0]
        flag = flag+1
    e_definition = best_synset.definition()
    
    return pd.Series([best_synset, e_definition, flag], index=['e_synset', 'e_definition', 'flag'])


In [10]:
def get_hypernym(synset1, synset2):
    common_hypernym = synset1.lowest_common_hypernyms(synset2)
    return common_hypernym[0].lemmas()[0].name() if common_hypernym else None

### 4. Extract Features

In [11]:
def extract_features(e1, e2, sentence):

    
    doc = nlp(sentence)

    e1_dep_token, e2_dep_token = 'NA.', 'NA.'
    e1_prev_token, e2_prev_token = 'NA.', 'NA.'
    e1_post_token, e2_post_token = 'NA.', 'NA.'

    e1_post_memory = e2_post_memory = False
    memory = 'NA.'

    for token in doc:
        
        if e1_post_memory == True:
            e1_post_token = str(token)
        e1_post_memory = False
        
        if e2_post_memory == True:
            e2_post_token = str(token)
        e2_post_memory = False
        
        if str(token) in e1.split():            
            e1_dep_token = str(token.head)
            
            if e1_prev_token == 'NA.':
                e1_prev_token = memory
                
            e1_post_memory = True
            
        if str(token) in e2.split(): 
            e2_dep_token = str(token.head)
            
            if e2_prev_token == 'NA.':
                e2_prev_token = memory
                
            e2_post_memory = True

        memory = str(token)
    
    return pd.Series([e1_dep_token, e1_prev_token, e1_post_token, e2_dep_token, e2_prev_token, e2_post_token],
                     index=['e1_dep_token', 'e1_prev_token', 'e1_post_token', 'e2_dep_token', 'e2_prev_token', 'e2_post_token'])


### 5. Sentence Tuning

In [12]:
def tune_sentence(new_sentence, e1, e2, e1_def, e2_def):
    
    doc = nlp(new_sentence)
    return_sentence = ''
    prev_word_pos = ''
    prev_word = ''
    sentence_entry = False
    
    # strip sentence between entities and select important tokens
    for word in doc:
        if str(word) == e1 and sentence_entry == False:
            sentence_entry = True
            end_type = e2
        if str(word) == e2 and sentence_entry == False:
            sentence_entry = True
            end_type = e1
            
        if (str(word) == e1 or str(word) == e2 or word.pos_ in ['ADP', 'AUX','VERB', 'X']) and sentence_entry==True:

            if word.pos_ =='ADP' and not(prev_word_pos in ['AUX','VERB', 'X'] or prev_word == e1 or prev_word == e2):
                pass
            else:
                return_sentence = return_sentence + ' ' + word.text
            
            prev_word_pos = word.pos_
            prev_word = str(word)
            
            if end_type == str(word):
                sentence_entry = False
                break
                
    # extract entity definition tokens
    def definition_tokens(doc):
        new_def = ''
        for word in doc:
            if (word.pos_ in ['ADP', 'AUX','VERB', 'X', 'NOUN']):
                new_def = new_def + ' ' + word.text
        return new_def
    
    new_e1_def = definition_tokens(nlp(e1_def))
    new_e2_def = definition_tokens(nlp(e2_def))
                  
    return pd.Series([return_sentence[1:], 
                      return_sentence[1:].strip() +' where '+e1+' is the '+new_e1_def.strip()+' and '+e2+' is the '+new_e2_def.strip()],
                     index=['trimmed_sentence', 'trimmed_sentence_wdef'])

### Apply All Functions-Preprocess Data

In [13]:
def preprocess_data(df_original):
    df = df_original.copy()
    df[['e1','e2']] = df['sentence'].swifter.apply(extract_entities)
    df['warning_flags'] = 0
    df[['corrected_sentence','corrected_e1','trimmed_sentence']] = df.swifter.apply(lambda x: synset_check(x['e1'],x['sentence']),axis=1)
    df[['corrected_sentence','corrected_e2','trimmed_sentence']] = df.swifter.apply(lambda x: synset_check(x['e2'],x['sentence']),axis=1)
    df[['e1_synset','e1_definition','warning_flags']] = df.swifter.apply(lambda x: disambiguate_entity_in_sentence(x['corrected_sentence'],x['corrected_e1'],x['warning_flags']), axis=1)
    df[['e2_synset','e2_definition','warning_flags']] = df.swifter.apply(lambda x: disambiguate_entity_in_sentence(x['corrected_sentence'],x['corrected_e2'],x['warning_flags']), axis=1)
    df[['e1_dep_token', 'e1_prev_token', 'e1_post_token', 'e2_dep_token', 'e2_prev_token', 'e2_post_token']] = df.swifter.apply(lambda x: extract_features(x['corrected_e1'],x['corrected_e2'],x['corrected_sentence']), axis=1)
    df[['trimmed_sentence', 'trimmed_sentence_wdef']] = df.swifter.apply(lambda x: tune_sentence(x['trimmed_sentence'],x['corrected_e1'],x['corrected_e2'],x['e1_definition'],x['e2_definition']),axis=1)

    return df

In [14]:
train_df = preprocess_data(train_df)
test_df = preprocess_data(test_df)

Pandas Apply:   0%|          | 0/8000 [00:00<?, ?it/s]

Pandas Apply:   0%|          | 0/8000 [00:00<?, ?it/s]

Pandas Apply:   0%|          | 0/8000 [00:00<?, ?it/s]

Pandas Apply:   0%|          | 0/8000 [00:00<?, ?it/s]

Synset for "natron" NOT FOUND
Synset for "opioids" NOT FOUND
Synset for "joey" NOT FOUND


Pandas Apply:   0%|          | 0/8000 [00:00<?, ?it/s]

Pandas Apply:   0%|          | 0/8000 [00:00<?, ?it/s]

Pandas Apply:   0%|          | 0/8000 [00:00<?, ?it/s]

Pandas Apply:   0%|          | 0/2717 [00:00<?, ?it/s]

Pandas Apply:   0%|          | 0/2717 [00:00<?, ?it/s]

Pandas Apply:   0%|          | 0/2717 [00:00<?, ?it/s]

Pandas Apply:   0%|          | 0/2717 [00:00<?, ?it/s]

Synset for "kimchi" NOT FOUND
Synset for "tempeh" NOT FOUND


Pandas Apply:   0%|          | 0/2717 [00:00<?, ?it/s]

Synset for "wiki" NOT FOUND
Synset for "prequels" NOT FOUND


Pandas Apply:   0%|          | 0/2717 [00:00<?, ?it/s]

Pandas Apply:   0%|          | 0/2717 [00:00<?, ?it/s]

# Modelling

### 1. Model Training & Validation

In [15]:
X = train_df.drop(columns=['relation']) 
y = train_df['relation']  

X_train, X_val, y_train, y_val = train_test_split(X, y, test_size=0.15, random_state=42)
X_train = X_train[X_train['warning_flags']==0]
X_train['relation'] = y_train
X_val['relation'] = y_val

In [16]:
class RelationExtraction(torch.nn.Module):
    def __init__(self, num_labels):
        super(RelationExtraction, self).__init__()

        self.bert = DistilBertModel.from_pretrained("distilbert-base-uncased")
        self.dropout1 = torch.nn.Dropout(p=0.05)
        self.fc1 = torch.nn.Linear(self.bert.config.hidden_size, num_labels)
        

    def forward(self, input_ids, attention_mask):

        outputs = self.bert(input_ids, attention_mask)
        
        last_hidden_state = outputs.last_hidden_state
        pooled_output = torch.mean(last_hidden_state, dim=1) 
        
        pooled_output = self.dropout1(pooled_output)
        output = self.fc1(pooled_output)
        logits = output
 
        return logits

In [17]:
def bert_model(train_df, test_df, epochs=5, batchsize=32, device=None):

    model = RelationExtraction(num_labels=len(train_df['relation'].unique())+1)
    optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5) 
    criterion = torch.nn.CrossEntropyLoss()


    
    def tokenize_batch(df, batchsize=32):
        tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")
        tokenized_inputs = tokenizer(df['trimmed_sentence_wdef'].tolist(),
                                     df['corrected_e1'].tolist(),
                                     df['corrected_e2'].tolist(),
                                     df['e1_dep_token'].tolist(),
                                     df['e2_dep_token'].tolist(),
                                     padding=True, 
                                     truncation=True,
                                     return_tensors='pt')
        labels =  torch.tensor(df['relation'].tolist())

        
        dataset = TensorDataset(tokenized_inputs['input_ids'], tokenized_inputs['attention_mask'], labels)
        loader = DataLoader(dataset, batch_size=batchsize)
        
        return loader
    
    def train_validate(model, optimizer, criterion ,epochs, device):
        
        torch.manual_seed(42)
        
        for epoch in range(epochs):
            print('********EPOCH', epoch+1, 'TRAINING********')
            model.train()
            total_loss = 0

            for batch in tqdm(tokenize_batch(train_df), total=int(len(train_df)/batchsize)):

                input_ids, attention_mask, labels = batch

                if device:  
                    input_ids = input_ids.to(device)
                    attention_mask = attention_mask.to(device)
                    labels = torch.tensor(labels).to(device)

                outputs = model(input_ids, attention_mask)
                logits = outputs

                loss = criterion(logits, labels)
                total_loss += loss.item()

                optimizer.zero_grad()
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                optimizer.step()

            # Print average loss per epoch
            avg_loss = total_loss / len(train_df)
            print(f"Epoch {epoch+1}/{epochs} | Average Loss: {avg_loss:.4f}")


            model.eval()
            total_eval_loss = 0
            total_correct = 0
            pred_list = []
            label_list = []
            
            print('********EPOCH', epoch+1, 'VALIDATION********')

            for batch in tqdm(tokenize_batch(test_df), total=int(len(test_df)/batchsize)):

                input_ids, attention_mask, labels = batch

                with torch.no_grad():
                    outputs = model(input_ids, attention_mask)
                    logits = outputs

                    # Calculate predictions
                    predictions = torch.argmax(logits, dim=-1)
                    if epochs == epoch + 1:
                        pred_list.extend(np.array(predictions))
                        label_list.extend(np.array(labels))
                    

                    # Calculate accuracy
                    total_correct += (predictions == labels).sum().item()

                loss = criterion(logits, labels)
                total_eval_loss += loss.item()

            # Print evaluation metrics
            eval_loss = total_eval_loss / len(test_df)
            accuracy = total_correct / len(test_df) * 100
            print(f"Epoch {epoch+1}/{epochs} | Eval Loss: {eval_loss:.4f} | Accuracy: {accuracy:.2f}%")
            
        return pred_list, label_list
            
    pred_list, label_list = train_validate(model, optimizer, criterion ,epochs, device)
    
    return model, pred_list, label_list

In [18]:
model, y_pred, y = bert_model(X_train, X_val)

********EPOCH 1 TRAINING********


213it [04:57,  1.40s/it]                                                        


Epoch 1/5 | Average Loss: 0.0577
********EPOCH 1 VALIDATION********


38it [00:12,  2.96it/s]                                                         


Epoch 1/5 | Eval Loss: 0.0349 | Accuracy: 67.50%
********EPOCH 2 TRAINING********


213it [05:04,  1.43s/it]                                                        


Epoch 2/5 | Average Loss: 0.0276
********EPOCH 2 VALIDATION********


38it [00:12,  3.06it/s]                                                         


Epoch 2/5 | Eval Loss: 0.0272 | Accuracy: 72.33%
********EPOCH 3 TRAINING********


213it [05:01,  1.42s/it]                                                        


Epoch 3/5 | Average Loss: 0.0179
********EPOCH 3 VALIDATION********


38it [00:13,  2.89it/s]                                                         


Epoch 3/5 | Eval Loss: 0.0254 | Accuracy: 75.58%
********EPOCH 4 TRAINING********


213it [05:07,  1.45s/it]                                                        


Epoch 4/5 | Average Loss: 0.0116
********EPOCH 4 VALIDATION********


38it [00:12,  3.00it/s]                                                         


Epoch 4/5 | Eval Loss: 0.0257 | Accuracy: 76.50%
********EPOCH 5 TRAINING********


213it [04:57,  1.40s/it]                                                        


Epoch 5/5 | Average Loss: 0.0071
********EPOCH 5 VALIDATION********


38it [00:12,  2.96it/s]                                                         


Epoch 5/5 | Eval Loss: 0.0274 | Accuracy: 78.08%


In [19]:
print(classification_report(y, y_pred))

              precision    recall  f1-score   support

           0       0.78      0.88      0.83        49
           1       0.93      0.83      0.88       107
           2       0.79      0.86      0.82        73
           3       0.76      0.73      0.74        70
           4       0.79      0.84      0.81        62
           5       0.90      0.87      0.88        30
           6       0.82      0.82      0.82       119
           8       0.73      0.72      0.73        86
           9       0.82      0.64      0.72        22
          10       0.78      0.64      0.70        11
          11       0.87      0.86      0.86        63
          12       1.00      0.64      0.78        14
          13       0.86      0.85      0.86        95
          14       0.86      0.83      0.84        78
          15       0.81      0.81      0.81        21
          16       0.73      0.80      0.77        45
          17       0.88      0.78      0.82        63
          18       0.58    

In [25]:
print (len(train_df[train_df['relation']==7]))

1


### Analysis

### Test Case:

In [20]:
def test_model(test_df, model):
    
    def tokenize_batch(df, batchsize=32):
        tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")
        tokenized_inputs = tokenizer(df['trimmed_sentence_wdef'].tolist(),
                                     df['corrected_e1'].tolist(),
                                     df['corrected_e2'].tolist(),
                                     df['e1_dep_token'].tolist(),
                                     df['e2_dep_token'].tolist(),
                                     padding=True, 
                                     truncation=True,
                                     return_tensors='pt')
        labels =  torch.tensor(df['relation'].tolist())

        
        dataset = TensorDataset(tokenized_inputs['input_ids'], tokenized_inputs['attention_mask'], labels)
        loader = DataLoader(dataset, batch_size=batchsize)
        
        return loader
    
    pred_list = []
    label_list = []
    batchsize = 32        
    for batch in tqdm(tokenize_batch(test_df), total=int(len(test_df)/batchsize)):

        input_ids, attention_mask, labels = batch

        with torch.no_grad():
            outputs = model(input_ids, attention_mask)

            predictions = torch.argmax(outputs, dim=-1)

            pred_list.extend(np.array(predictions))
            label_list.extend(np.array(labels))
    
    return pred_list, label_list

In [21]:
y_pred, y = test_model(test_df, model)
print(classification_report(y, y_pred))

85it [00:37,  2.25it/s]                                                         

              precision    recall  f1-score   support

           0       0.85      0.90      0.87       134
           1       0.88      0.86      0.87       194
           2       0.83      0.82      0.83       162
           3       0.76      0.70      0.73       150
           4       0.83      0.88      0.86       153
           5       0.71      0.77      0.74        39
           6       0.79      0.87      0.83       291
           7       0.00      0.00      0.00         1
           8       0.84      0.73      0.78       211
           9       0.88      0.79      0.83        47
          10       0.88      0.64      0.74        22
          11       0.75      0.63      0.69       134
          12       0.73      0.59      0.66        32
          13       0.89      0.82      0.85       201
          14       0.83      0.84      0.83       210
          15       0.81      0.75      0.78        51
          16       0.82      0.85      0.84       108
          17       0.76    




# Inference Mode

In [22]:
relation_name = {
0: 'Cause-Effect(e1,e2)',
1: 'Cause-Effect(e2,e1)',
2: 'Component-Whole(e1,e2)',
3: 'Component-Whole(e2,e1)',
4: 'Content-Container(e1,e2)',
5: 'Content-Container(e2,e1)',
6: 'Entity-Destination(e1,e2)',
7: 'Entity-Destination(e2,e1)',
8: 'Entity-Origin(e1,e2)',
9: 'Entity-Origin(e2,e1)',
10: 'Instrument-Agency(e1,e2)', 
11: 'Instrument-Agency(e2,e1)',
12: 'Member-Collection(e1,e2)',
13: 'Member-Collection(e2,e1)',
14: 'Message-Topic(e1,e2)',
15: 'Message-Topic(e2,e1)',
16: 'Product-Producer(e1,e2)',
17: 'Product-Producer(e2,e1)',
18: 'Other'
}

display(pd.DataFrame([(key, value) for key, value in relation_name.items()], columns=["relation", "relation type"]))

def inference_mode(sentence):
    if type(sentence)==str:
        sentence = [sentence]
    df = pd.DataFrame(sentence, columns = ['sentence'])
    df = preprocess_data(df)
    
    tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")
    tokenized_inputs = tokenizer(df['corrected_sentence'].tolist(),
                                     df['corrected_e1'].tolist(),
                                     df['corrected_e2'].tolist(),
                                     df['e1_dep_token'].tolist(),
                                     df['e2_dep_token'].tolist(),
                                     padding=True, 
                                     truncation=True,
                                     return_tensors='pt')
    
    output = model(tokenized_inputs['input_ids'], tokenized_inputs['attention_mask'])

    predicted_label = torch.argmax(output, dim=-1)
    df['predicted_relation'] = predicted_label
    df['predicted_relation_type'] = df['predicted_relation'].map(relation_name)
    return df[['sentence','predicted_relation_type']]
    


Unnamed: 0,relation,relation type
0,0,"Cause-Effect(e1,e2)"
1,1,"Cause-Effect(e2,e1)"
2,2,"Component-Whole(e1,e2)"
3,3,"Component-Whole(e2,e1)"
4,4,"Content-Container(e1,e2)"
5,5,"Content-Container(e2,e1)"
6,6,"Entity-Destination(e1,e2)"
7,7,"Entity-Destination(e2,e1)"
8,8,"Entity-Origin(e1,e2)"
9,9,"Entity-Origin(e2,e1)"


### Try Your Sentence

In [24]:
display (inference_mode('<e1>Apple</e1> produces <e2>Iphone</e2>.'))

Pandas Apply:   0%|          | 0/1 [00:00<?, ?it/s]

Pandas Apply:   0%|          | 0/1 [00:00<?, ?it/s]

Pandas Apply:   0%|          | 0/1 [00:00<?, ?it/s]

Pandas Apply:   0%|          | 0/1 [00:00<?, ?it/s]

Pandas Apply:   0%|          | 0/1 [00:00<?, ?it/s]

Pandas Apply:   0%|          | 0/1 [00:00<?, ?it/s]

Pandas Apply:   0%|          | 0/1 [00:00<?, ?it/s]

Unnamed: 0,sentence,predicted_relation_type
0,<e1>Apple</e1> produces <e2>Iphone</e2>.,"Product-Producer(e2,e1)"
