In [38]:
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset, random_split
from torch.cuda.amp import GradScaler, autocast  # For mixed precision training
from sklearn.model_selection import StratifiedShuffleSplit
from sklearn.model_selection import train_test_split
import sentencepiece as spm 
import lightning as L
from lightning.pytorch.tuner import Tuner
from torchmetrics.classification import MulticlassF1Score
from lightning.pytorch.callbacks import ModelCheckpoint, StochasticWeightAveraging
from lightning.pytorch.callbacks.early_stopping import EarlyStopping
import numpy as np
from sklearn.preprocessing import LabelEncoder
from sklearn.metrics import f1_score, accuracy_score, classification_report
import pickle
from datetime import datetime

In [2]:
data = pd.read_csv("./data/lexeme_data_with_forms_3.csv")

In [3]:
data

Unnamed: 0,lexeme,language,pos,stem_text,contlex,forms,lexeme_and_forms,bpe
0,taibsted,sms,V,taaibâst,V_MAINSTED,taaibâst taaibast taaibstem taibstem taaibsti ...,taibsted taaibstâkaz taibstâkaz,▁tai b sted ▁taa ib stâkaz ▁tai b stâkaz
1,ääʹll,sms,N,ää%{ʹØ%}ll,N_SAAQMM,ääʹlest älla ääʹli aaʹli ääʹll ääʹl ääʹlstes ä...,ääʹll ällsan,▁ää ʹ ll ▁ä llsan
2,njââʹllvaaldõs,sms,N,njââʹllvaaldõ%^1VOW%{ʹØ%}s,N_SAJOS,njââʹllvaaldõõzzâst njââʹllvaaldõʹsse njââʹllv...,njââʹllvaaldõs njââʹllvaaldõssʼsan,▁njââ ʹ ll vaald õs ▁njââ ʹ ll vaa ldõss ʼ san
3,läukkad,sms,V,lä%^1VOWukk,V_LAEULLAD,lääuk läukk laukkum lääukai läukkad lääukam la...,läukkad läukkaz,▁läu kkad ▁läu kkaz
4,laukkõõllâd,sms,V,laukkõõ%{ʹØ%}ll,V_LAUKKOOLLYD,laukkõõl laukkââll laukkõʹllem laukkõõli laukk...,laukkõõllâd laukkâllaz,▁lau kk õõllâd ▁lau kk âllaz
...,...,...,...,...,...,...,...,...
24034,jieʹllidåhttar,sms,N,jieʹlli#dåhttar,N_AANAR,,jieʹllidåhttar,▁jie ʹ ll id åhtt ar
24035,nuõrrǥaž,sms,N,nuõrrǥ,N_MEERSAZH,,nuõrrǥaž,▁nuõ rr ǥ až
24036,looǥǥâlm,sms,N,looǥǥâlm,N_COOGGYLM,,looǥǥâlm,▁looǥǥ âlm
24037,paneelsaǥstõõllmõš,sms,N,panẹẹl#saǥ»stõõll»mõ%^1VOW%{ʹØ%}š,N_SAJOS,,paneelsaǥstõõllmõš,▁pan ee l sa ǥ stõõllmõš


In [4]:
clean_data = data.copy()
clean_data.dropna(subset=["forms"], inplace=True)
clean_data.drop(columns=["lexeme_and_forms", "bpe"], inplace=True)

In [5]:
clean_data

Unnamed: 0,lexeme,language,pos,stem_text,contlex,forms
0,taibsted,sms,V,taaibâst,V_MAINSTED,taaibâst taaibast taaibstem taibstem taaibsti ...
1,ääʹll,sms,N,ää%{ʹØ%}ll,N_SAAQMM,ääʹlest älla ääʹli aaʹli ääʹll ääʹl ääʹlstes ä...
2,njââʹllvaaldõs,sms,N,njââʹllvaaldõ%^1VOW%{ʹØ%}s,N_SAJOS,njââʹllvaaldõõzzâst njââʹllvaaldõʹsse njââʹllv...
3,läukkad,sms,V,lä%^1VOWukk,V_LAEULLAD,lääuk läukk laukkum lääukai läukkad lääukam la...
4,laukkõõllâd,sms,V,laukkõõ%{ʹØ%}ll,V_LAUKKOOLLYD,laukkõõl laukkââll laukkõʹllem laukkõõli laukk...
...,...,...,...,...,...,...
24027,riikkvääraiministeria,sms,N,riikk#väärai#ministeria,N_BIOLOGIA,riikkvääraiministeriast riikkvääraiministeriaa...
24028,nõõmuʹvddem-moʹlidva,sms,N,nõõm#uʹvddem-#moʹlidva,N_BUKVA,nõõmuʹvddem-moʹlidvast nõõmuʹvddem-moʹlidvaaʹj...
24029,teevvamhåidd,sms,N,teevvam#hå%^1VOWidd,N_AELDD,teevvamhååidast teevvamhoiddu teevvamhååidai t...
24030,koomačkåhtt,sms,N,koomač#kå%^1VOWhtt,N_KOAHTT,koomačkååutast koomačkohttu koomačkååutai koom...


In [6]:
len(clean_data["contlex"].unique())

939

In [7]:
clean_data["label"] = clean_data["contlex"].str.split('_').str[:2].str.join('_')

In [8]:
clean_data

Unnamed: 0,lexeme,language,pos,stem_text,contlex,forms,label
0,taibsted,sms,V,taaibâst,V_MAINSTED,taaibâst taaibast taaibstem taibstem taaibsti ...,V_MAINSTED
1,ääʹll,sms,N,ää%{ʹØ%}ll,N_SAAQMM,ääʹlest älla ääʹli aaʹli ääʹll ääʹl ääʹlstes ä...,N_SAAQMM
2,njââʹllvaaldõs,sms,N,njââʹllvaaldõ%^1VOW%{ʹØ%}s,N_SAJOS,njââʹllvaaldõõzzâst njââʹllvaaldõʹsse njââʹllv...,N_SAJOS
3,läukkad,sms,V,lä%^1VOWukk,V_LAEULLAD,lääuk läukk laukkum lääukai läukkad lääukam la...,V_LAEULLAD
4,laukkõõllâd,sms,V,laukkõõ%{ʹØ%}ll,V_LAUKKOOLLYD,laukkõõl laukkââll laukkõʹllem laukkõõli laukk...,V_LAUKKOOLLYD
...,...,...,...,...,...,...,...
24027,riikkvääraiministeria,sms,N,riikk#väärai#ministeria,N_BIOLOGIA,riikkvääraiministeriast riikkvääraiministeriaa...,N_BIOLOGIA
24028,nõõmuʹvddem-moʹlidva,sms,N,nõõm#uʹvddem-#moʹlidva,N_BUKVA,nõõmuʹvddem-moʹlidvast nõõmuʹvddem-moʹlidvaaʹj...,N_BUKVA
24029,teevvamhåidd,sms,N,teevvam#hå%^1VOWidd,N_AELDD,teevvamhååidast teevvamhoiddu teevvamhååidai t...,N_AELDD
24030,koomačkåhtt,sms,N,koomač#kå%^1VOWhtt,N_KOAHTT,koomačkååutast koomačkohttu koomačkååutai koom...,N_KOAHTT


In [9]:
label_counts = clean_data["label"].value_counts()
label_counts

label
N_SAJOS           5969
N_MAINSTUMMUSH    1557
N_MUORR            792
N_AANAR            616
V_LAUKKOOLLYD      609
                  ... 
N_CHEE               1
IV_TEYPSTED          1
IV_LEEQD             1
N_PUUQTTES           1
N_KARIES             1
Name: count, Length: 514, dtype: int64

In [10]:
counts_of_counts = label_counts.value_counts()
counts_of_counts

count
1      104
2       57
3       35
5       32
4       31
      ... 
136      1
144      1
149      1
151      1
63       1
Name: count, Length: 103, dtype: int64

In [11]:
def remove_rare_combinations(data, contlex_column='contlex', min_samples=20):
    label_counts = data[contlex_column].value_counts()
    
    frequent_labels = label_counts[label_counts >= min_samples].index

    filtered_data = data[data[contlex_column].isin(frequent_labels)]
    
    return filtered_data.reset_index(drop=True)


In [12]:
clean_data = remove_rare_combinations(clean_data, contlex_column='label', min_samples=50)

clean_data

Unnamed: 0,lexeme,language,pos,stem_text,contlex,forms,label
0,ääʹll,sms,N,ää%{ʹØ%}ll,N_SAAQMM,ääʹlest älla ääʹli aaʹli ääʹll ääʹl ääʹlstes ä...,N_SAAQMM
1,njââʹllvaaldõs,sms,N,njââʹllvaaldõ%^1VOW%{ʹØ%}s,N_SAJOS,njââʹllvaaldõõzzâst njââʹllvaaldõʹsse njââʹllv...,N_SAJOS
2,laukkõõllâd,sms,V,laukkõõ%{ʹØ%}ll,V_LAUKKOOLLYD,laukkõõl laukkââll laukkõʹllem laukkõõli laukk...,V_LAUKKOOLLYD
3,hiâvtõõttâd,sms,V,hiâvtõõ%{ʹØ%}tt,V_LAUKKOOLLYD,hiâvtõõđ hiâvtââtt hiâvtõʹttem hiâvtõõđi hiâvt...,V_LAUKKOOLLYD
4,hiâvtõõttâd,sms,V,hiõvtõõ%{ʹØ%}tt,V_LAUKKOOLLYD_ERRORTH,hiâvtõõđ hiâvtââtt hiâvtõʹttem hiâvtõõđi hiâvt...,V_LAUKKOOLLYD
...,...,...,...,...,...,...,...
18572,puåtkknõddâd,sms,V,puåtkknõ%^1VOW%{ʹØ%}dd,V_ROVVYD,puåtkknõõdd puåtkknâdd puåtkknõʹddem puåtkknõõ...,V_ROVVYD
18573,ruõššlaž,sms,N,ruõšˈšl,N_MEERSAZH_SEMHUM,ruõššlast ruõššlõʹžže ruõššlai ruõššlaž ruõššl...,N_MEERSAZH
18574,riikkvääraiministeria,sms,N,riikk#väärai#ministeria,N_BIOLOGIA,riikkvääraiministeriast riikkvääraiministeriaa...,N_BIOLOGIA
18575,teevvamhåidd,sms,N,teevvam#hå%^1VOWidd,N_AELDD,teevvamhååidast teevvamhoiddu teevvamhååidai t...,N_AELDD


In [13]:
clean_data["label"].value_counts()

label
N_SAJOS           5969
N_MAINSTUMMUSH    1557
N_MUORR            792
N_AANAR            616
V_LAUKKOOLLYD      609
                  ... 
V_ROVVYD            54
N_TAQHTT            52
N_KHEQRJJ           52
V_JEAELSTED         52
N_SIYKKK            52
Name: count, Length: 73, dtype: int64

In [14]:
len(clean_data["label"].unique())

73

In [15]:
clean_data.groupby("pos")["label"].nunique().to_dict()

{'N': 52, 'V': 21}

In [16]:
sp = spm.SentencePieceProcessor(model_file='./skolt_bpe_3.model')

In [17]:
class LabelEncoderManager:
    def __init__(self):
        self.pos_encoder = LabelEncoder()
        self.contlex_encoders = {}  # Dictionary to hold a LabelEncoder for each POS class

    def fit(self, pos_labels, contlex_labels):
        """
        Fit the POS encoder and the contlex encoders based on the corresponding POS.
        
        pos_labels: List or array of POS labels.
        contlex_labels: List or array of contlex labels.
        """
        # Fit the POS encoder
        self.pos_encoder.fit(pos_labels)

        # Initialize and fit a LabelEncoder for each unique POS class
        unique_pos = set(pos_labels)
        for pos_class in unique_pos:
            contlex_for_pos = [contlex_labels[i] for i in range(len(pos_labels)) if pos_labels[i] == pos_class]
            encoder = LabelEncoder()
            encoder.fit(contlex_for_pos)
            self.contlex_encoders[pos_class] = encoder

    def transform_pos(self, pos_labels):
        """Transform POS labels using the POS encoder."""
        return self.pos_encoder.transform(pos_labels)

    def inverse_transform_pos(self, encoded_pos_labels):
        """Inverse transform encoded POS labels using the POS encoder."""
        return self.pos_encoder.inverse_transform(encoded_pos_labels)

    def transform_contlex(self, pos_labels, contlex_labels):
        """
        Transform contlex labels using the corresponding encoder for each POS class.
        
        pos_labels: List or array of POS labels (used to select the corresponding contlex encoder).
        contlex_labels: List or array of contlex labels to be encoded.
        """
        encoded_contlex_labels = []
        for pos, contlex in zip(pos_labels, contlex_labels):
            encoder = self.contlex_encoders.get(pos)
            if encoder is not None:
                encoded_contlex_labels.append(encoder.transform([contlex])[0])
            else:
                raise ValueError(f"No contlex encoder found for POS class: {pos}")
        return encoded_contlex_labels

    def inverse_transform_contlex(self, pos_labels, encoded_contlex_labels):
        """
        Inverse transform contlex labels using the corresponding encoder for each POS class.
        
        pos_labels: List or array of POS labels (used to select the corresponding contlex encoder).
        encoded_contlex_labels: List or array of encoded contlex labels to be inverse transformed.
        """
        contlex_labels = []
        for pos, enc_contlex in zip(pos_labels, encoded_contlex_labels):
            encoder = self.contlex_encoders.get(pos)
            if encoder is not None:
                contlex_labels.append(encoder.inverse_transform([enc_contlex])[0])
            else:
                raise ValueError(f"No contlex encoder found for POS class: {pos}")
        return contlex_labels

    def save_encoders(self, path):
        """
        Save the POS encoder and contlex encoders to a file using pickle.
        
        path: The file path where the encoders will be saved.
        """
        with open(path, 'wb') as f:
            pickle.dump({'pos_encoder': self.pos_encoder, 'contlex_encoders': self.contlex_encoders}, f)

    def load_encoders(self, path):
        """
        Load the POS encoder and contlex encoders from a file using pickle.
        
        path: The file path where the encoders are stored.
        """
        with open(path, 'rb') as f:
            data = pickle.load(f)
            self.pos_encoder = data['pos_encoder']
            self.contlex_encoders = data['contlex_encoders']

In [18]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [19]:
class CustomDataset(Dataset):
    def __init__(self, X, pos_labels, contlex_labels):
        self.X = X
        self.pos_labels = pos_labels
        self.contlex_labels = contlex_labels

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        X_item = torch.tensor(self.X[idx], dtype=torch.long)
        pos_label = torch.tensor(self.pos_labels[idx], dtype=torch.long)
        contlex_label = torch.tensor(self.contlex_labels[idx], dtype=torch.long)
        return X_item, pos_label, contlex_label

In [20]:
clean_data["X"] = clean_data["lexeme"] + " " + clean_data["forms"]

In [21]:
training_data, test_data = train_test_split(clean_data, test_size=0.1, random_state=42, stratify=clean_data["label"])

In [22]:
training_data

Unnamed: 0,lexeme,language,pos,stem_text,contlex,forms,label,X
6909,põõrǥâståimm,sms,N,põõrǥâs#tå%^1VOWimm,N_AELDD,põõrǥâstååimast põõrǥâstoimmu põõrǥâstååimai p...,N_AELDD,põõrǥâståimm põõrǥâstååimast põõrǥâstoimmu põõ...
7341,ǩeʹrjjvacc,sms,N,ǩeʹrjj#va%^1VOW%{ʹØ%}cc,N_PAPP,ǩeʹrjjvaaccâst ǩeʹrjjvaʹcce ǩeʹrjjvaacci ǩeʹrj...,N_PAPP,ǩeʹrjjvacc ǩeʹrjjvaaccâst ǩeʹrjjvaʹcce ǩeʹrjjv...
8005,tieʹddtummuš,sms,N,tieʹddtummuš,N_MAINSTUMMUSH,tieʹddtummšest tieʹddtummša tieʹddtummši tieʹd...,N_MAINSTUMMUSH,tieʹddtummuš tieʹddtummšest tieʹddtummša tieʹd...
1155,vuõjjšõõddmõš,sms,N,vuõjjšõõdd»mõ%^1VOW%{ʹØ%}š,N_SAJOS,vuõjjšõõddmuužžâst vuõjjšõõddmõõžžâst vuõjjšõõ...,N_SAJOS,vuõjjšõõddmõš vuõjjšõõddmuužžâst vuõjjšõõddmõõ...
9847,sǩiâŋkk,sms,N,#sǩiâ%{ʹØ%}ŋkk,N_MIYRKK,sǩiâŋkâst sǩieʹŋǩǩe sǩiâŋki sǩiâŋkk sǩiâŋk sǩi...,N_MIYRKK,sǩiâŋkk sǩiâŋkâst sǩieʹŋǩǩe sǩiâŋki sǩiâŋkk sǩ...
...,...,...,...,...,...,...,...,...
13105,vuõleed,sms,V,vuõl,V_SILTTEED,vuõlâd vuõlad vuõleem vuõlii vuõleed vuõlääm v...,V_SILTTEED,vuõleed vuõlâd vuõlad vuõleem vuõlii vuõleed v...
4110,sieʹǩǩporrmõš,sms,N,sieʹǩǩporr»mõ%^1VOW%{ʹØ%}š,N_SAJOS,sieʹǩǩporrmuužžâst sieʹǩǩporrmõõžžâst sieʹǩǩpo...,N_SAJOS,sieʹǩǩporrmõš sieʹǩǩporrmuužžâst sieʹǩǩporrmõõ...
2677,čårrmeäʹcc,sms,N,čårr#meä%{ʹØ%}cˈc,N_JEAQNNN,čårrmieʹccest čårrmeäcca čårrmieʹcci čårrmeäʹc...,N_JEAQNNN,čårrmeäʹcc čårrmieʹccest čårrmeäcca čårrmieʹcc...
15078,jiâkstõõttâd,sms,V,jiâˈkstõõ%{ʹØ%}tt,V_LAUKKOOLLYD,jiâkstõõđ jiâkstââtt jiâkstõʹttem jiâkstõõđi j...,V_LAUKKOOLLYD,jiâkstõõttâd jiâkstõõđ jiâkstââtt jiâkstõʹttem...


In [23]:
def tokenize_input(texts, sp_model):
    """
    Tokenize input text using a SentencePiece model.
    
    texts: List of input texts to be tokenized.
    sp_model: SentencePiece model to tokenize the texts.
    
    Returns:
    tokenized_texts: List of tokenized and padded input sequences.
    max_len: Maximum sequence length.
    vocab_size: Size of the vocabulary.
    """
    tokenized_texts = [sp_model.encode(text, out_type=int) for text in texts]
    max_len = max(len(x) for x in tokenized_texts)
    padded_texts = [x + [0] * (max_len - len(x)) for x in tokenized_texts]  # Padding with 0
    return padded_texts, max_len, sp_model.get_piece_size()

In [24]:
X, max_len, vocab_size = tokenize_input(training_data['X'].tolist(), sp)
pos_labels = training_data['pos'].tolist()
contlex_labels = training_data['label'].tolist()

In [25]:
encoder_manager = LabelEncoderManager()

encoder_manager.fit(pos_labels, contlex_labels)

encoded_pos = encoder_manager.transform_pos(pos_labels)
encoded_contlex = encoder_manager.transform_contlex(pos_labels, contlex_labels)

# encoder_manager.save_encoders('label_encoders_3.pkl')

encoder_manager.load_encoders('label_encoders_3.pkl')

In [26]:
contlex_output_map = training_data.groupby("pos")["label"].nunique().to_dict()
contlex_output_map = {encoder_manager.transform_pos([k])[0]: v for k, v in contlex_output_map.items()}
contlex_output_map

{0: 52, 1: 21}

In [27]:
X_train, X_val, pos_train, pos_val, contlex_train, contlex_val = train_test_split(X, encoded_pos, encoded_contlex, test_size=0.2, random_state=42, stratify=contlex_labels)

In [28]:
train_dataset = CustomDataset(X_train, pos_train, contlex_train)
val_dataset = CustomDataset(X_val, pos_val, contlex_val)

train_dataloader = DataLoader(train_dataset, batch_size=512, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=512, shuffle=False)

In [29]:
class SharedEmbeddingTransformer(L.LightningModule):
    def __init__(
        self,
        pos_num_classes,
        contlex_output_map,
        vocab_size,
        embed_size=96,
        hidden_size=128,
        num_layers=2,
        nhead=4,
        dropout=0.1,
        learning_rate=1e-3,
        batch_size=32,
    ):
        super(SharedEmbeddingTransformer, self).__init__()
        self.save_hyperparameters()

        self.embedding = nn.Embedding(
            num_embeddings=vocab_size, embedding_dim=embed_size, padding_idx=0
        )

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=embed_size,
            nhead=nhead,
            dim_feedforward=hidden_size,
            dropout=dropout,
            batch_first=True
        )
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

        self.fc_out_pos = nn.Linear(embed_size, pos_num_classes)

        self.fc_out_contlex = nn.ModuleDict(
            {
                str(pos_class): nn.Linear(embed_size, contlex_output_map[pos_class])
                for pos_class in contlex_output_map
            }
        )

        nn.init.xavier_uniform_(self.embedding.weight)
        nn.init.xavier_uniform_(self.fc_out_pos.weight)

        for pos_class in contlex_output_map:
            nn.init.xavier_uniform_(self.fc_out_contlex[str(pos_class)].weight)

        self.pos_weight = 1.0
        self.contlex_weight = 1.0

        self.learning_rate = learning_rate
        self.batch_size = batch_size
        self.contlex_output_map = contlex_output_map
        self.pos_class_sorted = list(sorted(contlex_output_map.keys()))
        self.contlex_class_max = max(contlex_output_map.values())

        self.pos_f1 = MulticlassF1Score(num_classes=pos_num_classes, average="weighted")
        self.contlex_f1 = MulticlassF1Score(num_classes=self.contlex_class_max, average="weighted")

    def forward(self, x, pos_target=None):
        x = self.embedding(x)
        transformer_out = self.transformer_encoder(x)

        transformer_last_hidden = transformer_out[:, -1, :]

        pos_output = self.fc_out_pos(transformer_last_hidden)

        pos_labels = pos_target if pos_target is not None else torch.argmax(pos_output, dim=1) 

        contlex_output = torch.zeros((x.size(0), self.contlex_class_max), device=self.device)

        for pos_class in self.pos_class_sorted:
            contlex_size = self.hparams.contlex_output_map[pos_class]
            indices = (pos_labels == int(pos_class)).nonzero(as_tuple=True)[0]

            if len(indices) > 0:
                fc_out_contlex = self.fc_out_contlex[str(pos_class)]
                contlex_out = fc_out_contlex(transformer_last_hidden[indices])
                contlex_output[indices, :contlex_size] = contlex_out

        return pos_output, contlex_output

    def custom_loss(self, pos_output, contlex_output, pos_target, contlex_target):
        pos_loss = nn.CrossEntropyLoss()(pos_output, pos_target)

        contlex_loss = nn.CrossEntropyLoss()(contlex_output, contlex_target)

        total_loss = self.pos_weight * pos_loss + self.contlex_weight * contlex_loss
        return total_loss, pos_loss, contlex_loss

    def training_step(self, batch, batch_idx):
        x, pos_y, contlex_y = batch
        pos_output, contlex_output = self(x, pos_target=pos_y)

        total_loss, pos_loss, contlex_loss = self.custom_loss(pos_output, contlex_output, pos_y, contlex_y)
        self.log("train_loss", total_loss, on_step=True, on_epoch=True, prog_bar=True)
        self.log("train_pos_loss", pos_loss, on_step=True, on_epoch=True, prog_bar=True)
        self.log("train_contlex_loss", contlex_loss, on_step=True, on_epoch=True, prog_bar=True)

        return total_loss

    def validation_step(self, batch, batch_idx):
        x, pos_y, contlex_y = batch
        pos_output, contlex_output = self(x, pos_target=pos_y)

        total_loss, pos_loss, contlex_loss = self.custom_loss(pos_output, contlex_output, pos_y, contlex_y)
        self.log("val_loss", total_loss, on_step=False, on_epoch=True, prog_bar=True)
        self.log("val_pos_loss", pos_loss, on_step=False, on_epoch=True, prog_bar=True)
        self.log("val_contlex_loss", contlex_loss, on_step=False, on_epoch=True, prog_bar=True)

        pos_preds = torch.argmax(pos_output, dim=1)
        pos_acc = (pos_preds == pos_y).float().mean()
        self.log("val_pos_acc", pos_acc, on_step=False, on_epoch=True, prog_bar=True)

        contlex_preds = torch.argmax(contlex_output, dim=1)                
        contlex_acc = (contlex_preds == contlex_y).float().mean()
        self.log(f"val_contlex_acc", contlex_acc, on_step=False, on_epoch=True, prog_bar=True)

        self.pos_f1(pos_preds, pos_y)
        self.log("val_pos_f1", self.pos_f1, on_step=False, on_epoch=True, prog_bar=True)

        self.contlex_f1(contlex_preds, contlex_y)
        self.log("val_contlex_f1", self.contlex_f1, on_step=False, on_epoch=True, prog_bar=True)

        self.log("lr", self.trainer.optimizers[0].param_groups[0]['lr'], on_step=False, on_epoch=True, prog_bar=True)


    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(self.parameters(), lr=self.learning_rate)
        # scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.95, patience=10)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=25)
        # scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.95)
        return {"optimizer": optimizer, "lr_scheduler": scheduler, "monitor": "val_loss"}
        
    
    def train_dataloader(self):
        return DataLoader(train_dataset, batch_size=self.batch_size, shuffle=True)
    
    def val_dataloader(self):
        return DataLoader(val_dataset, batch_size=self.batch_size, shuffle=True)

In [None]:
# Hyperparameters
embed_size = 128
hidden_size = 512
num_layers = 3
nhead = 8
dropout = 0.2
epochs = 100
learning_rate = 0.003
batch_size = 512

pos_num_classes = len(set(pos_labels))

date_time_str = datetime.now().strftime("%Y-%m-%d-%H-%M-%S")

model = SharedEmbeddingTransformer(
    pos_num_classes, 
    contlex_output_map, 
    vocab_size, 
    embed_size=embed_size,
    hidden_size=hidden_size, 
    num_layers=num_layers,
    nhead=nhead,
    dropout=dropout,
    learning_rate=learning_rate,
    batch_size=batch_size
)

print(model.hparams)

In [None]:
checkpoint_callback = ModelCheckpoint(
     monitor='val_loss',
     dirpath=f"./model/{date_time_str}/",
     filename='model-{epoch:02d}-{val_loss:.2f}-{val_contlex_f1:.2f}',
     mode="min",
     enable_version_counter=True,
    #  save_last=True,
)

early_stop_callback = EarlyStopping(monitor="val_loss", min_delta=0.00, patience=10, verbose=False, mode="min")


trainer = L.Trainer(
    min_epochs=100,
    max_epochs=epochs,
    accumulate_grad_batches=4,
    callbacks=[
        # early_stop_callback, 
        StochasticWeightAveraging(swa_lrs=1e-2), checkpoint_callback]
)

# tuner = Tuner(trainer)
# tuner.scale_batch_size(model, mode="power")
# lr_finder = tuner.lr_find(model)

In [None]:
# fig = lr_finder.plot(suggest=True)
# fig.show()

In [None]:
# # TO TRAIN, uncomment
# trainer.fit(model, train_dataloader, val_dataloader)

In [30]:
model= SharedEmbeddingTransformer.load_from_checkpoint(f"./model/BEST/final_model.ckpt")

In [31]:
test_data

Unnamed: 0,lexeme,language,pos,stem_text,contlex,forms,label,X
12779,nõmmpeiʹvv,sms,N,nõmm#pẹ%^1VOWi%{ʹØ%}vv,N_PEIQVV,nõmmpeeiʹvest nõmmpeivva nõmmpeeiʹvi nõmmpeiʹv...,N_PEIQVV,nõmmpeiʹvv nõmmpeeiʹvest nõmmpeivva nõmmpeeiʹv...
4633,mååustpeʹrrjummuš,sms,N,mååustpeʹrr»jummuš,N_MAINSTUMMUSH,mååustpeʹrrjummšest mååustpeʹrrjummša mååustpe...,N_MAINSTUMMUSH,mååustpeʹrrjummuš mååustpeʹrrjummšest mååustpe...
13527,vuetǩǩummuš,sms,N,vuetǩǩummuš,N_MAINSTUMMUSH,vuetǩǩummšest vuetǩǩummša vuetǩǩummši vuetǩǩum...,N_MAINSTUMMUSH,vuetǩǩummuš vuetǩǩummšest vuetǩǩummša vuetǩǩum...
16047,säähharpõõst,sms,N,säähhar#põ%^1VOW%{ʹØ%}stt,N_ALGG_PL,säähharpõõsti säähharpõsttân,N_ALGG,säähharpõõst säähharpõõsti säähharpõsttân
3795,čåuʹjjmoiʹvvjummuš,sms,N,čåuʹjjmoiʹvv»jummuš,N_MAINSTUMMUSH,čåuʹjjmoiʹvvjummšest čåuʹjjmoiʹvvjummša čåuʹjj...,N_MAINSTUMMUSH,čåuʹjjmoiʹvvjummuš čåuʹjjmoiʹvvjummšest čåuʹjj...
...,...,...,...,...,...,...,...,...
17588,ämmat-tuʹtǩǩõs,sms,N,ämmat#tuʹtǩǩõ%^1VOW%{ʹØ%}s,N_SAJOS,ämmat-tuʹtǩǩõõzzâst ämmat-tuʹtǩǩõʹsse ämmat-tu...,N_SAJOS,ämmat-tuʹtǩǩõs ämmat-tuʹtǩǩõõzzâst ämmat-tuʹtǩ...
12648,njiõkknjâʹstted,sms,V,njiõkknjâ%^1VOW%{ʹØ%}stt,V_CEQPCCED,njiõkknjââʹst njiõkknjâstt njiõkknjõʹsttem nji...,V_CEQPCCED,njiõkknjâʹstted njiõkknjââʹst njiõkknjâstt nji...
15104,rieʹǧǧtummuš,sms,N,rieʹǧǧtummuš,N_MAINSTUMMUSH,rieʹǧǧtummšest rieʹǧǧtummša rieʹǧǧtummši rieʹǧ...,N_MAINSTUMMUSH,rieʹǧǧtummuš rieʹǧǧtummšest rieʹǧǧtummša rieʹǧ...
102,kuärččjed,sms,V,kuärčč,V_KUYDHDHDHJED,kuårčču kuärččai kuärččjem kuärččji kuärččjed ...,V_KUYDHDHDHJED,kuärččjed kuårčču kuärččai kuärččjem kuärččji ...


In [32]:
print(model.hparams)
model.eval()
model = model.to(device)

"batch_size":         512
"contlex_output_map": {0: 52, 1: 21}
"dropout":            0.2
"embed_size":         128
"hidden_size":        512
"learning_rate":      0.003
"nhead":              8
"num_layers":         3
"pos_num_classes":    2
"vocab_size":         2000


In [33]:
for i, row in test_data.iterrows():
    x_tokenized = sp.encode(row["X"], out_type=int)
    x = torch.tensor([x_tokenized], dtype=torch.long).to(device)
    pos_output, _ = model(x)
    pos_pred = torch.argmax(pos_output, dim=1)

    pos_target_main = encoder_manager.transform_pos([row["pos"]])
    pos_target = torch.tensor(pos_target_main, dtype=torch.long).to(device)

    _, contlex_output = model(x, pos_target=pos_target)
    contlex_pred = torch.argmax(contlex_output, dim=1)
    pos_label = encoder_manager.inverse_transform_pos([pos_pred.item()])[0]
    contlex_label = encoder_manager.inverse_transform_contlex([row["pos"]], [contlex_pred.item()])[0]
    test_data.at[i, f"pos_pred"] = pos_label
    test_data.at[i, f"contlex_pred"] = contlex_label

print(classification_report(test_data["pos"], test_data[f"pos_pred"]))
print(classification_report(test_data["label"], test_data[f"contlex_pred"]))


              precision    recall  f1-score   support

           N       1.00      1.00      1.00      1520
           V       1.00      1.00      1.00       338

    accuracy                           1.00      1858
   macro avg       1.00      1.00      1.00      1858
weighted avg       1.00      1.00      1.00      1858

                precision    recall  f1-score   support

IV_LAUKKOOLLYD       0.69      0.88      0.77        33
     N_AACCIKH       1.00      1.00      1.00        18
       N_AANAR       0.98      0.95      0.97        62
       N_AELDD       0.70      0.88      0.78        24
        N_ALGG       0.88      0.90      0.89        51
        N_ATOM       0.25      0.17      0.20         6
        N_AUTT       0.61      0.85      0.71        13
    N_BIOLOGIA       0.94      1.00      0.97        15
     N_CHAAQCC       0.94      1.00      0.97        16
    N_CHUAQRVV       1.00      0.71      0.83         7
    N_CHUOSHKK       0.83      0.71      0.77         7


In [36]:
max_len_of_X_words = test_data["X"].apply(lambda x: x.count(" ")).max()
max_len_of_X_words

52

In [40]:
for j in range(1, max_len_of_X_words + 1):
    print(f"Max number of words: {j}")

    for i, row in test_data.iterrows():
        X_with_max_j_words = " ".join(row["X"].split()[:j])
        x_tokenized = sp.encode(X_with_max_j_words, out_type=int)
        x = torch.tensor([x_tokenized], dtype=torch.long).to(device)
        pos_output, _ = model(x)
        pos_pred = torch.argmax(pos_output, dim=1)

        pos_target_main = encoder_manager.transform_pos([row["pos"]])
        pos_target = torch.tensor(pos_target_main, dtype=torch.long).to(device)

        _, contlex_output = model(x, pos_target=pos_target)
        contlex_pred = torch.argmax(contlex_output, dim=1)
        pos_label = encoder_manager.inverse_transform_pos([pos_pred.item()])[0]
        contlex_label = encoder_manager.inverse_transform_contlex([row["pos"]], [contlex_pred.item()])[0]
        test_data.at[i, f"pos_pred"] = pos_label
        test_data.at[i, f"contlex_pred"] = contlex_label

    print("POS Accuracy:", accuracy_score(test_data["pos"], test_data[f"pos_pred"]))
    print("Contlex Accuracy:", accuracy_score(test_data["label"], test_data[f"contlex_pred"]))
    print()


Max number of words: 1
POS Accuracy: 0.9736275565123789
Contlex Accuracy: 0.36544671689989233

Max number of words: 2
POS Accuracy: 0.9757804090419806
Contlex Accuracy: 0.569967707212056

Max number of words: 3
POS Accuracy: 0.9644779332615716
Contlex Accuracy: 0.6178686759956943

Max number of words: 4
POS Accuracy: 0.9720129171151776
Contlex Accuracy: 0.6792249730893434

Max number of words: 5
POS Accuracy: 0.9741657696447793
Contlex Accuracy: 0.689989235737352

Max number of words: 6
POS Accuracy: 0.9752421959095802
Contlex Accuracy: 0.6926803013993541

Max number of words: 7
POS Accuracy: 0.9790096878363832
Contlex Accuracy: 0.7072120559741658

Max number of words: 8
POS Accuracy: 0.988697524219591
Contlex Accuracy: 0.7561894510226049

Max number of words: 9
POS Accuracy: 0.9827771797631862
Contlex Accuracy: 0.7734122712594187

Max number of words: 10
POS Accuracy: 0.9860064585575888
Contlex Accuracy: 0.778794402583423

Max number of words: 11
POS Accuracy: 0.9881593110871906
Contl