In [23]:
from kaggle_secrets import UserSecretsClient
secrets = UserSecretsClient()
COMET_API_KEY = secrets.get_secret('comet_api_key')

In [24]:
from pathlib import Path
import os
import random


os.environ["TOKENIZERS_PARALLELISM"] = "false"

FLAGS = {
    # batch size used in model
    'batch_size': 12,
    # batch size for metrics calculaation
    'batch_size_for_knn': 256,
    # max epochs for training
    'max_epochs': 3,
    # to set manual seed, random if None
    'SEED': None,
    # one of {'train', 'valid', 'test'}
    'stage': 'test',
    # to use checkpoint, set None to use defeault CodeBERT
    'CHECKPOINT_PATH': '/kaggle/input/final-weights/semihard+hard_checkpoints/epoch=2-step=24471.ckpt',
    # name for logger experiment
    'EXPERIMENT_NAME': "semihard+hard_test",
    # name for logger project
    'PROJECT_NAME': 'typebert4py',
    # to follow more simple precossing steps
    'preprocess_as_hityper': False,
    # to do fast run with 10% of data
    'trial_run': False
}

MODEL_PARAMS = {
    # margin value
    'margin': 2,
    # number of infered types
    'number_of_neighbors_to_find': 10,
    # learning rate
    'lr': 1e-5,
    # either HARD or SEMI_HARD
    'loss_type': 'HARD',
    # to use same parameters as in checkpoint
    'keep_same': True,
    # embedding dimensionality the model produces. CodeBERT specific
    'emb_dim': 768,
    # number of tokens model can accept. CodeBERT specific
    'model_max_length': 512
}


# dataset directory
DATASET_DIR = Path('/kaggle/input/manytypes4pyaggregated')
TRAIN_PATH = DATASET_DIR / 'train.csv'
TEST_PATH = DATASET_DIR / 'test.csv'
VALID_PATH = DATASET_DIR / 'valid.csv'

if not FLAGS['SEED']:
    FLAGS['SEED'] = random.randint(1, 1000)

In [25]:
!pip install pytorch-metric-learning -q
!pip install comet-ml -q
!pip install faiss-gpu -q

[0m

In [26]:
import torch
import torch.nn as nn

from pytorch_metric_learning import miners, losses
from pytorch_metric_learning.utils import accuracy_calculator

from tokenizers import processors
from transformers import AutoTokenizer, AutoModel
from datasets import load_dataset

import faiss
import faiss.contrib.torch_utils


import pytorch_lightning as pl
from pytorch_lightning.callbacks import TQDMProgressBar, EarlyStopping, ModelCheckpoint, Callback
from pytorch_lightning.strategies import DDPStrategy
from pytorch_lightning.loggers import CometLogger


from torchmetrics import MeanMetric

try:
   import ipywidgets
   from tqdm.auto import tqdm
except ImportError as e:
   from tqdm import tqdm

from collections import defaultdict
from functools import partial
from ast import literal_eval
from enum import Enum, auto
import gc
import math
from collections import Counter
import regex

In [27]:
pl.seed_everything(FLAGS['SEED'])

983

In [28]:
# calculate the number of types before preprocessing


# data = load_dataset("csv", data_files={"train": str(TRAIN_PATH), "test": str(TEST_PATH), "valid": str(VALID_PATH)})
# train_data = data['train'] #.train_test_split(test_size=1000, shuffle=False)['test']
# valid_data = data['valid'] #.train_test_split(test_size=1000, shuffle=False)['test']
# test_data = data['test'] #.train_test_split(test_size=1000, shuffle=False)['test']

# annotations_num_train = 0
# types_train = set()
# for ex in tqdm(train_data):
#     typed_seq = literal_eval(ex['typed_seq'])
#     annotations_num_train += sum(map(lambda type_: type_ != '0', typed_seq))
#     types_train.update(typed_seq)

# annotations_num_valid = 0
# types_valid = set()
# for ex in tqdm(valid_data):
#     typed_seq = literal_eval(ex['typed_seq'])
#     annotations_num_valid += sum(map(lambda type_: type_ != '0', typed_seq))
#     types_valid.update(typed_seq)
    
# annotations_num_test = 0
# types_test = set()
# for ex in tqdm(test_data):
#     typed_seq = literal_eval(ex['typed_seq'])
#     annotations_num_test += sum(map(lambda type_: type_ != '0', typed_seq))
#     types_test.update(typed_seq)

# print('train annotations:', annotations_num_train)
# print('train types num:', len(types_train) - 1)

# print('valid annotations:', annotations_num_valid)
# print('valid types num:', len(types_valid) - 1)

# print('test annotations:', annotations_num_test)
# print('test types num:', len(types_test) - 1)

# all_types = set()
# all_types.update(types_train)
# all_types.update(types_valid)
# all_types.update(types_test)
# print(len(all_types)-1)
# print(annotations_num_train+annotations_num_valid+annotations_num_test)

In [29]:
SUB_REGEX = r'typing\.|typing_extensions\.|t\.|builtins\.|collections\.'

TYPE_ALIASES = {'(?<=.*)any(?<=.*)|(?<=.*)unknown(?<=.*)': 'Any',
                '^{}$|^Dict$|^Dict\[\]$|(?<=.*)Dict\[Any, *?Any\](?=.*)|^Dict\[unknown, *Any\]$': 'dict',
                '^Set$|(?<=.*)Set\[\](?<=.*)|^Set\[Any\]$': 'set',
                '^Tuple$|(?<=.*)Tuple\[\](?<=.*)|^Tuple\[Any\]$|(?<=.*)Tuple\[Any, *?\.\.\.\](?=.*)|^Tuple\[unknown, *?unknown\]$|^Tuple\[unknown, *?Any\]$|(?<=.*)tuple\[\](?<=.*)': 'tuple',
                '^Tuple\[(.+), *?\.\.\.\]$': r'Tuple[\1]',
                '\\bText\\b': 'str',
                '^\[\]$|(?<=.*)List\[\](?<=.*)|^List\[Any\]$|^List$': 'list',
                '^\[{}\]$': 'List[dict]',
                '(?<=.*)Literal\[\'.*?\'\](?=.*)': 'Literal',
                '(?<=.*)Literal\[\d+\](?=.*)': 'Literal',
                '^Callable\[\.\.\., *?Any\]$|^Callable\[\[Any\], *?Any\]$|^Callable[[Named(x, Any)], Any]$': 'Callable',
                '^Iterator[Any]$': 'Iterator',
                '^OrderedDict[Any, *?Any]$': 'OrderedDict',
                '^Counter[Any]$': 'Counter',
                '(?<=.*)Match[Any](?<=.*)': 'Match'}

EXCLUDE_TYPES = ['Any', 'None', 'object', 'type', 'Type[Any]',
                    'Type[cls]', 'Type[type]', 'Type', 'TypeVar', 'Optional[Any]']

UBIQUITOUS_TYPES = {'str', 'int', 'list', 'bool', 'float'}
UBIQUITOUS_TYPE_IDS = set(range(len(UBIQUITOUS_TYPES)))
UBIQUITOUS_TYPE2ID = {k: v for k, v in zip(UBIQUITOUS_TYPES, UBIQUITOUS_TYPE_IDS)}

# the tokenized code is stored as str, however it is a list. Here the convertation is done 
def batch_str_to_list(batch):
    batch['untyped_seq'] = [literal_eval(seq) for seq in batch['untyped_seq']]
    batch['typed_seq'] = [literal_eval(seq) for seq in batch['typed_seq']]
    return batch

# removing undesired tokens, replacing [EOL] by \n
def preprocess_normalized_seq2seq(untyped_seq, typed_seq):
    new_untyped, new_typed = [], []
    for untyped, typed in zip(untyped_seq, typed_seq):
        if untyped == '[EOL]':
            if len(new_untyped) != 0:
                if new_untyped[-1][-1] != '\n':
                    new_untyped[-1] += '\n'
        elif untyped == '[docstring]' or untyped == '[comment]':
            continue
        else:
            new_untyped.append(untyped)
            new_typed.append(typed)
    return new_untyped, new_typed
 
# preprocess_normalized_seq2seq for batches
def batch_preprocess_normalized_seq2seq(batch):
    new_untyped_batch, new_typed_batch = [], []
    for untyped_seq, typed_seq in zip(batch['untyped_seq'], batch['typed_seq']):
        new_untyped, new_typed = preprocess_normalized_seq2seq(untyped_seq, typed_seq)
        new_untyped_batch.append(new_untyped)
        new_typed_batch.append(new_typed)
    batch['untyped_seq'] = new_untyped_batch
    batch['typed_seq'] = new_typed_batch
    return batch


def preprocess_types_as_hitiper(typed_seq):
    new_seq = []
    for t in typed_seq:
        if t == '$typing.Any$':
            new_seq.append('0')
        else:
            new_seq.append(t)
    return new_seq

def batch_preprocess_types_as_hitiper(batch):
    batch['typed_seq'] = [preprocess_types_as_hitiper(typed_seq) for typed_seq in batch['typed_seq']]
    return batch

# make one type consistent
def make_consistent(t):
    if t == '0':
        return t
    return regex.sub(SUB_REGEX, "", t[1:-1])

# make consistent an example from dataset
def ex_make_consistent(ex):
    ex['typed_seq'] = ['0' if t == '0' else make_consistent(t) for t in ex['typed_seq']]
    return ex

def remove_quote_types(t: str):
    if t == '0':
        return t
    s = regex.search(r'^\'(.+)\'$', t)
    if bool(s):
        return s.group(1)
    else:
        return t
    
def ex_remove_quote_types(ex):
    ex['typed_seq'] = ['0' if t == '0' else remove_quote_types(t) for t in ex['typed_seq']]
    return ex

def exclude_types(t):
    return '0' if t in EXCLUDE_TYPES else t

def ex_exclude_types(ex):
    ex['typed_seq'] = ['0' if t == '0' else exclude_types(t) for t in ex['typed_seq']]
    return ex

def resolve_type_alias(t: str):
    if t == '0':
        return t  
    for t_alias in TYPE_ALIASES:
        if regex.search(regex.compile(t_alias), t):
            t = regex.sub(regex.compile(t_alias), TYPE_ALIASES[t_alias], t)
    return t

def ex_resolve_type_alias(ex):
    ex['typed_seq'] = ['0' if t == '0' else resolve_type_alias(t) for t in ex['typed_seq']]
    return ex

# reducing parametric types
def reduce_parameters(t):
    nested_level = 1
    new_t = ''
    for c in t:
        if c == '[':
            nested_level += 1
            if nested_level == 3:
                new_t += '[Any]'
                
        if c == ']':
            nested_level -= 1
            if nested_level == 2:
                continue
        
        if nested_level > 2:
            continue
        else:
            new_t += c
    return new_t

def ex_reduce_parameters(ex):
    ex['typed_seq'] = ['0' if t == '0' else reduce_parameters(t) for t in ex['typed_seq']]
    return ex

def remove_trivial_annotations(untyped_seq, typed_seq):
    new_seq = []
    for token, t in zip(untyped_seq, typed_seq):
        if t == '__len__' or t == '__str__':
            new_seq.append('0')
        else:
            new_seq.append(t)
    return new_seq

def batch_remove_trivial_annotations(batch):
    batch['typed_seq'] = [remove_trivial_annotations(untyped_seq, typed_seq) for untyped_seq, typed_seq in zip(batch['untyped_seq'], batch['typed_seq'])]
    return batch

def ex_has_type(ex):
    return any(map(lambda t: t != '0', ex['typed_seq']))


class SeqType(Enum):
    INPUT_IDS = auto()
    ATTENTION_MASK = auto()
    TYPED_SEQ = auto()


class TokenizeTransform(object):
    def __init__(self, tokenizer):
        self.tokenizer = tokenizer
        self.type2id = defaultdict(lambda: len(self.type2id))
        self.type2id['0'] = -100
        self.type2id['-100'] = -100
        self.type2id.update(UBIQUITOUS_TYPE2ID)
        

    def __call__(self, batch):
        return self.create_one_batch(batch)
    
    def get_id2type(self):
        return {v: k for k, v in self.type2id.items()}

    def create_one_batch(self, batch):
        batch_encoding = self.tokenizer(
            text = [['']] * len(batch['untyped_seq']),
            text_pair = batch['untyped_seq'],
            add_special_tokens=False,
            truncation = False,
            padding = False,
            return_token_type_ids=True,
            is_split_into_words=True
        )

        
        input_ids_batch = batch_encoding.input_ids        
        attention_mask_batch = batch_encoding.attention_mask
        
        typed_seq_batch = TokenizeTransform.align_typed_seq(batch_encoding, batch)
        typed_seq_batch_wo_t_params = TokenizeTransform.batch_remove_type_parameters(typed_seq_batch)
        typed_seq_batch = self.assign_ids_to_typed_seq_batch(typed_seq_batch)
        typed_seq_batch_wo_t_params = self.assign_ids_to_typed_seq_batch(typed_seq_batch_wo_t_params)

        input_ids_batch = self.batch_split_to_chunks_pad_add_special_tokens(input_ids_batch, SeqType.INPUT_IDS)
        attention_mask_batch = self.batch_split_to_chunks_pad_add_special_tokens(attention_mask_batch, SeqType.ATTENTION_MASK)
        typed_seq_batch = self.batch_split_to_chunks_pad_add_special_tokens(typed_seq_batch, SeqType.TYPED_SEQ)
        typed_seq_batch_wo_t_params = self.batch_split_to_chunks_pad_add_special_tokens(typed_seq_batch_wo_t_params, SeqType.TYPED_SEQ)
        
        return {'input_ids_batch': input_ids_batch,
                'attention_mask_batch': attention_mask_batch,
                'typed_seq_batch': typed_seq_batch,
                'typed_seq_batch_wo_t_params': typed_seq_batch_wo_t_params
               }
    
    def align_typed_seq(batch_encoding, batch):
        typed_seq_batch = []
        for batch_idx, typed_seq in enumerate(batch['typed_seq']):
            word_ids = batch_encoding.word_ids(batch_idx)
            previous_word_idx = None
            labels = []
            for word_idx in word_ids:
                if word_idx is None or word_idx == previous_word_idx:
                    labels.append('-100')
                else:
                    labels.append(typed_seq[word_idx])
                previous_word_idx = word_idx
            
            typed_seq_batch.append(labels)
            
        return typed_seq_batch
    
    
    def assign_ids_to_typed_seq_batch(self, typed_seq_batch):
        return list(map(
            lambda typed_seq: list(map(
                lambda type_: self.type2id[type_],
                typed_seq)),
            typed_seq_batch)) 
    
    def remove_type_parameters(typed_seq):
        new_typed_seq = []
        for type_ in typed_seq:
            beg = type_.find('[')
            if beg == -1:
                new_typed_seq.append(type_)
            else:
                new_typed_seq.append(type_[:beg])
        return new_typed_seq
        
    def batch_remove_type_parameters(typed_seq_batch):
        return [TokenizeTransform.remove_type_parameters(typed_seq) for typed_seq in typed_seq_batch]
    
    def batch_split_to_chunks_pad_add_special_tokens(self, batch, seq_type: SeqType):
        chunkenized_batch = []
        for lst in batch:
            chunkenized_batch += self.split_to_chunks_pad_add_special_tokens(lst, seq_type)
        return chunkenized_batch
        
    def split_to_chunks_pad_add_special_tokens(self, lst, seq_type: SeqType):
        if seq_type == SeqType.INPUT_IDS:
            padding_token = self.tokenizer.pad_token_id
        elif seq_type == SeqType.ATTENTION_MASK:
            padding_token = 0
        else:
            padding_token = -100
            
        chunks = []
        chunk_len = self.tokenizer.model_max_length - 3
        # dividing to chunks
        for i in range(0, len(lst), chunk_len):
            el = lst[i:i + chunk_len]
            # adding special tokens
            el = self.add_special_tokens(el, seq_type)
            
            # padding
            if len(el) != self.tokenizer.model_max_length:
                el += [padding_token] * (self.tokenizer.model_max_length - len(el))
            chunks.append(el)
        return chunks
    
    def add_special_tokens(self, lst, seq_type: SeqType):        
        if seq_type == SeqType.INPUT_IDS:
            lst = [self.tokenizer.cls_token_id] + [self.tokenizer.sep_token_id] + lst + [self.tokenizer.eos_token_id]
        elif seq_type == SeqType.ATTENTION_MASK:
            lst = [1] + [1] + lst + [1]
        else:
            lst = [-100] + [-100] + lst + [-100]
        return lst
    

class PreprocessData:
    def __init__(self, tokenize_transform):
        self.tokenize_transform = tokenize_transform
    
    def get_id2type(self):
        return self.tokenize_transform.get_id2type()
    
    def get_type2id(self):
        return self.tokenize_transform.type2id
    
    def __call__(self, data):
        # doing all preprocessing
        if FLAGS['preprocess_as_hityper']:
            return data\
                .map(batch_str_to_list, batched=True, desc="Strings to lists")\
                .filter(lambda example: len(example['untyped_seq']) == len(example['typed_seq']), desc='Removing examples with unuequal sequences len')\
                .map(batch_preprocess_normalized_seq2seq, batched=True, desc='Preprocessing normalizedseq2seq')\
                .map(batch_preprocess_types_as_hitiper, batched=True, desc='Preprocessing types',)\
                .map(self.tokenize_transform, batched=True, remove_columns=data.column_names, batch_size=FLAGS['batch_size'], desc='Tokenizing for the model').with_format("torch")\
                .filter(lambda example: any(example['typed_seq_batch'] != -100), desc='Removing examples with no annotations')
        else:
            return data\
                .map(batch_str_to_list, batched=True, desc="Strings to lists")\
                .filter(lambda example: len(example['untyped_seq']) == len(example['typed_seq']), desc='Removing examples with unuequal sequences len')\
                .map(batch_preprocess_normalized_seq2seq, batched=True, desc='Preprocessing normalizedseq2seq')\
                .map(ex_make_consistent, desc='Making consistent')\
                .map(ex_remove_quote_types, desc='Removing quote types')\
                .map(ex_exclude_types, desc='Excluding types')\
                .map(ex_resolve_type_alias, desc='Resolving type alias')\
                .map(ex_reduce_parameters, desc='Reducing parameters')\
                .map(batch_remove_trivial_annotations, batched=True, desc='Removing trivial annotations')\
                .filter(lambda ex: ex_has_type(ex), desc='Removing examples with no annotations')\
                .map(self.tokenize_transform, batched=True, remove_columns=data.column_names, batch_size=FLAGS['batch_size'], desc='Tokenizing for the model').with_format("torch")\
                .filter(lambda example: any(example['typed_seq_batch'] != -100), desc='Removing examples with no annotations')



tokenizer = AutoTokenizer.from_pretrained("microsoft/codebert-base", add_prefix_space=True)
tokenizer._tokenizer.post_processor = processors.BertProcessing(
    sep=("</s>", tokenizer._tokenizer.token_to_id("</s>")),
    cls=("<s>", tokenizer._tokenizer.token_to_id("<s>"))
)
new_tokens = ['[number]', '[string]']
new_tokens = set(new_tokens) - set(tokenizer.vocab.keys())
tokenizer.add_tokens(list(new_tokens))

preprocess_data = PreprocessData(TokenizeTransform(tokenizer))

In [30]:
# dataset
class ManyTypes4PyDataModule(pl.LightningDataModule):
    def __init__(self, train_path, test_path, valid_path, trial_run):
        super().__init__()
        self.train_path = train_path
        self.test_path = test_path
        self.valid_path = valid_path
        self.trial_run = trial_run


    def setup(self, stage: str):
        data = load_dataset("csv", data_files={"train": str(self.train_path), "test": str(self.test_path), "valid": str(self.valid_path)})
        if self.trial_run:
            train_data = data['train'].train_test_split(test_size=0.1, shuffle=False)['test']
            valid_data = data['valid'].train_test_split(test_size=0.1, shuffle=False)['test']
            test_data = data['test'].train_test_split(test_size=0.1, shuffle=False)['test']
        else:
            train_data = data['train']
            valid_data = data['valid']
            test_data = data['test']
        
        
        if stage == "fit":
            self.train_dataset = preprocess_data(train_data)
            self.valid_dataset = preprocess_data(valid_data)
            
        if stage == "test":
            if not hasattr(self, 'train_dataset'):
                self.train_dataset = preprocess_data(train_data)
            self.test_dataset = preprocess_data(test_data)

    def train_dataloader(self):
        return torch.utils.data.DataLoader(
            self.train_dataset,
            batch_size=FLAGS['batch_size'],
            drop_last=True,
            shuffle=True
        )

    def val_dataloader(self):
        dataloaders = [
            torch.utils.data.DataLoader(
                self.train_dataset,
                batch_size=FLAGS['batch_size'],
                drop_last=True
            ),
            torch.utils.data.DataLoader(
                self.valid_dataset,
                batch_size=FLAGS['batch_size'],
                drop_last=True
            )
            
        ]
        return dataloaders

    def test_dataloader(self):
        dataloaders = [
            torch.utils.data.DataLoader(
                self.train_dataset,
                batch_size=FLAGS['batch_size'],
                drop_last=True
            ),
            torch.utils.data.DataLoader(
                self.test_dataset,
                batch_size=FLAGS['batch_size'],
                drop_last=True
            )
            
        ]
        return dataloaders

dm = ManyTypes4PyDataModule(train_path=TRAIN_PATH, test_path=TEST_PATH, valid_path=VALID_PATH, trial_run=FLAGS['trial_run'])

In [31]:
def get_types_count(types: torch.Tensor):
    return Counter(types.cpu().numpy())

def get_is_ubiquitous_mask(types: torch.Tensor, ubiquitous_types: set):
    return torch.tensor([t.item() in ubiquitous_types for t in types], device=types.device)

def get_is_common_mask(types: torch.Tensor, types_count: Counter, ubiquitous_types: set):
    return torch.tensor([types_count[t.item()] > 100 and t.item() not in ubiquitous_types for t in types], device=types.device)

def get_is_rare_mask(types: torch.Tensor, types_count: Counter):
    return torch.tensor([types_count[t.item()] <= 100 for t in types], device=types.device)

In [32]:
# calculate the number of types after preprocessing


# def count_types(train_dl):
#     types = Counter()
#     for batch in tqdm(train_dl):
#         typed_seq_batch = batch['typed_seq_batch']
#         has_type_annotation_mask = typed_seq_batch != -100
#         type_ids = typed_seq_batch[has_type_annotation_mask]

#         types.update(type_ids.flatten().tolist())

    
#     return types


# def get_stats(dl, types_count):
#     types = []
#     for batch in tqdm(dl):
#         typed_seq_batch = batch['typed_seq_batch']
#         has_type_annotation_mask = typed_seq_batch != -100
#         type_ids = typed_seq_batch[has_type_annotation_mask]

#         types += type_ids.flatten().tolist()
        
    
#     ubiquitous = [t for t in types if t in UBIQUITOUS_TYPE_IDS]
#     common = [t for t in types if types_count[t] > 100 and t not in UBIQUITOUS_TYPE_IDS]
#     rare = [t for t in types if types_count[t] <= 100]

#     all_num = len(types)
#     all_unique = len(set(types))
    
#     ubiquitous_num = len(ubiquitous)
#     ubiquitous_unique = len(set(ubiquitous))
    
#     common_num = len(common)
#     common_unique = len(set(common))
    
#     rare_num = len(rare)
#     rare_unique = len(set(rare))
    
#     return {
#             'all_num': all_num,
#             'all_unique': all_unique,
#             'ubiquitous_num': ubiquitous_num,
#             'ubiquitous_unique': ubiquitous_unique,
#             'common_num': common_num,
#             'common_unique': common_unique,
#             'rare_num': rare_num,
#             'rare_unique': rare_unique
#         }, set(types)


# dm.setup('fit')
# types_count = count_types(dm.train_dataloader())


# train_stats, train_types = get_stats(dm.train_dataloader(), types_count)

# valid_stats, valid_types = get_stats(dm.val_dataloader()[1], types_count)

# dm.setup('test')
# test_stats, test_types = get_stats(dm.test_dataloader()[1], types_count)


# overall_unique = set()
# overall_unique.update(train_types)
# overall_unique.update(valid_types)
# overall_unique.update(test_types)

# print('Overall unique:', len(overall_unique))
# print('train stats:', train_stats)
# print('valid stats:', valid_stats)
# print('test stats:', test_stats)

  0%|          | 0/3 [00:00<?, ?it/s]

Strings to lists:   0%|          | 0/64 [00:00<?, ?ba/s]

Removing examples with unuequal sequences len:   0%|          | 0/64 [00:00<?, ?ba/s]

Preprocessing normalizedseq2seq:   0%|          | 0/64 [00:00<?, ?ba/s]

Making consistent:   0%|          | 0/63593 [00:00<?, ?ex/s]

Removing quote types:   0%|          | 0/63593 [00:00<?, ?ex/s]

Excluding types:   0%|          | 0/63593 [00:00<?, ?ex/s]

Resolving type alias:   0%|          | 0/63593 [00:00<?, ?ex/s]

Reducing parameters:   0%|          | 0/63593 [00:00<?, ?ex/s]

Removing trivial annotations:   0%|          | 0/64 [00:00<?, ?ba/s]

Removing examples with no annotations:   0%|          | 0/64 [00:00<?, ?ba/s]

Tokenizing for the model:   0%|          | 0/4738 [00:00<?, ?ba/s]

Token indices sequence length is longer than the specified maximum sequence length for this model (599 > 512). Running this sequence through the model will result in indexing errors


Removing examples with no annotations:   0%|          | 0/142 [00:00<?, ?ba/s]

Strings to lists:   0%|          | 0/7 [00:00<?, ?ba/s]

Removing examples with unuequal sequences len:   0%|          | 0/7 [00:00<?, ?ba/s]

Preprocessing normalizedseq2seq:   0%|          | 0/7 [00:00<?, ?ba/s]

Making consistent:   0%|          | 0/6936 [00:00<?, ?ex/s]

Removing quote types:   0%|          | 0/6936 [00:00<?, ?ex/s]

Excluding types:   0%|          | 0/6936 [00:00<?, ?ex/s]

Resolving type alias:   0%|          | 0/6936 [00:00<?, ?ex/s]

Reducing parameters:   0%|          | 0/6936 [00:00<?, ?ex/s]

Removing trivial annotations:   0%|          | 0/7 [00:00<?, ?ba/s]

Removing examples with no annotations:   0%|          | 0/7 [00:00<?, ?ba/s]

Tokenizing for the model:   0%|          | 0/520 [00:00<?, ?ba/s]

Removing examples with no annotations:   0%|          | 0/16 [00:00<?, ?ba/s]

  0%|          | 0/8157 [00:00<?, ?it/s]

  0%|          | 0/8157 [00:00<?, ?it/s]

  0%|          | 0/893 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

Strings to lists:   0%|          | 0/18 [00:00<?, ?ba/s]

Removing examples with unuequal sequences len:   0%|          | 0/18 [00:00<?, ?ba/s]

Preprocessing normalizedseq2seq:   0%|          | 0/18 [00:00<?, ?ba/s]

Making consistent:   0%|          | 0/17616 [00:00<?, ?ex/s]

Removing quote types:   0%|          | 0/17616 [00:00<?, ?ex/s]

Excluding types:   0%|          | 0/17616 [00:00<?, ?ex/s]

Resolving type alias:   0%|          | 0/17616 [00:00<?, ?ex/s]

Reducing parameters:   0%|          | 0/17616 [00:00<?, ?ex/s]

Removing trivial annotations:   0%|          | 0/18 [00:00<?, ?ba/s]

Removing examples with no annotations:   0%|          | 0/18 [00:00<?, ?ba/s]

Tokenizing for the model:   0%|          | 0/1313 [00:00<?, ?ba/s]

Removing examples with no annotations:   0%|          | 0/41 [00:00<?, ?ba/s]

  0%|          | 0/2264 [00:00<?, ?it/s]

Overall unique: 40074
train stats: {'all_num': 886306, 'all_unique': 30803, 'ubiquitous_num': 446872, 'ubiquitous_unique': 5, 'common_num': 276017, 'common_unique': 318, 'rare_num': 163417, 'rare_unique': 30480}
valid stats: {'all_num': 94545, 'all_unique': 4736, 'ubiquitous_num': 49143, 'ubiquitous_unique': 5, 'common_num': 26735, 'common_unique': 244, 'rare_num': 18667, 'rare_unique': 4487}
test stats: {'all_num': 240860, 'all_unique': 10372, 'ubiquitous_num': 121273, 'ubiquitous_unique': 5, 'common_num': 71650, 'common_unique': 270, 'rare_num': 47937, 'rare_unique': 10097}


In [10]:
# peforms KNN lookups
class FaissKNeighbors:
    def __init__(self, is_cuda, emb_dim):
        self.is_cuda = is_cuda
        self.index = faiss.IndexFlatL2(emb_dim)
        if self.is_cuda:
            self.index = faiss.index_cpu_to_all_gpus(self.index)

    def add(self, X):
        self.index.add(X)

    def predict(self, X, k):
        distances, indices = self.index.search(X, k=k)
        return indices

    def reset(self):
        self.index.reset()

    def __del__(self):
        self.index.reset()


# returns reciprocal ranks
def get_rrs(query_labels, pred_labels, k):
    if k > pred_labels.shape[-1]:
        raise Exception('k is greater than the number of predictions for each retreival')
    targets = query_labels.unsqueeze(-1).expand(-1, k)
    preds = pred_labels[:, :k]

    all_ranks = torch.arange(1, k+1, device=preds.device).unsqueeze(0).expand(preds.shape[0], -1)
    valid_ranks = torch.where(preds == targets, all_ranks, k+1)
    rank_per_query = valid_ranks.min(dim=1).values
    filtered_ranks = rank_per_query[rank_per_query != k+1]
    reciprocal_ranks = 1 / filtered_ranks
    reciprocal_ranks = torch.cat((reciprocal_ranks, torch.zeros((preds.shape[0] - reciprocal_ranks.shape[0]), device=preds.device)))
    return reciprocal_ranks


# returns whether there is a heat for each retreival
def get_hits(query_labels, pred_labels, k):
    if k > pred_labels.shape[-1]:
        raise Exception('k is greater than the number of predictions for each retreival')

    targets = query_labels.unsqueeze(-1).expand((-1, k))
    preds = pred_labels[:, :k]
    hits = torch.eq(targets, preds).any(1)
    return hits


class MetricsCaclulator:
    def __init__(self, is_cuda, emb_dim=MODEL_PARAMS['emb_dim']):
        self.faiss = FaissKNeighbors(is_cuda, emb_dim=emb_dim)
        
    def calculate_metrics(self, reference, reference_labels,
                     query, query_labels, k=10, show_progress=True, title=''):
        self.faiss.add(reference.cpu())
        metrics = self.calculate_all(reference_labels, query, query_labels, k, show_progress)
        self.faiss.reset()
        if title:
            return {f'{title}_{k}': v for k, v in metrics.items()}
        else:
            return metrics
    
    # calculates metrics for the entire dataset
    def calculate_all(self, reference_labels, query, query_labels, k, show_progress):
        dl = torch.utils.data.DataLoader(list(zip(query, query_labels)), batch_size=FLAGS['batch_size_for_knn'])
        reciprocal_ranks, top1_hits, top3_hits, top5_hits, top10_hits = [], [], [], [], []
        for query_batch, query_labels_batch in tqdm(dl, desc='Calculating Metrics', disable=not show_progress):
            results_of_batch = self.calculate_batch(reference_labels=reference_labels, query=query_batch, query_labels=query_labels_batch, k=k)
            reciprocal_ranks.append(results_of_batch[0])
            top1_hits.append(results_of_batch[1])
            top3_hits.append(results_of_batch[2])
            top5_hits.append(results_of_batch[3])
            top10_hits.append(results_of_batch[4])
        
        mrr = torch.cat(reciprocal_ranks).mean()
        hit_rate_top1 = torch.cat(top1_hits).to(torch.float32).mean()
        hit_rate_top3 = torch.cat(top3_hits).to(torch.float32).mean()
        hit_rate_top5 = torch.cat(top5_hits).to(torch.float32).mean()
        hit_rate_top10 = torch.cat(top10_hits).to(torch.float32).mean()
        
        del reciprocal_ranks
        del top1_hits
        del top3_hits
        del top5_hits
        del top10_hits
        gc.collect()
        
        metrics = {
            f'MRR@{k}': mrr,
            'hit_rate_top1': hit_rate_top1,
            'hit_rate_top3': hit_rate_top3,
            'hit_rate_top5': hit_rate_top5,
            'hit_rate_top10': hit_rate_top10
        }
        return metrics

    # calculate for a batch
    def calculate_batch(self, reference_labels, query, query_labels, k):
        # N, K
        pred_idxs = self.faiss.predict(query.cpu(), k=k).to(query_labels.device)
        pred_labels = reference_labels.expand(pred_idxs.shape[0], -1).gather(index=pred_idxs, dim=1)

        reciprocal_ranks = get_rrs(query_labels, pred_labels, k)
        top1_hits = get_hits(query_labels, pred_labels, 1)
        top3_hits = get_hits(query_labels, pred_labels, 3)
        top5_hits = get_hits(query_labels, pred_labels, 5)
        top10_hits = get_hits(query_labels, pred_labels, 10)
        
        return reciprocal_ranks, top1_hits, top3_hits, top5_hits, top10_hits
    
metrics_calculator = MetricsCaclulator(is_cuda=True)
# function for metrics calculation
calculate_metrics = metrics_calculator.calculate_metrics

In [13]:
class Model(pl.LightningModule):
    def __init__(self, margin, loss_type, token_embeddings_size, learning_rate, batch_size, seed=None):
        # setting up the model
        super().__init__()
        self.save_hyperparameters()

        self.semi_hard_miner = miners.TripletMarginMiner(margin=self.hparams.margin, type_of_triplets='semihard')
        self.hard_miner = miners.TripletMarginMiner(margin=self.hparams.margin, type_of_triplets='hard')
        self.loss_fn = losses.TripletMarginLoss(margin=self.hparams.margin)
        
        self.train_semi_hard_loss = MeanMetric(nan_strategy='error')
        self.train_hard_loss = MeanMetric(nan_strategy='error')
        

        self.model = AutoModel.from_pretrained("microsoft/codebert-base")
        self.model.resize_token_embeddings(token_embeddings_size)
        for p in self.model.parameters():
            p.requires_grad = True


    def forward(self, input_ids_batch, attention_mask_batch):
        embs = self.model(input_ids=input_ids_batch, attention_mask=attention_mask_batch).last_hidden_state
        return embs


    def training_step(self, batch, batch_nb):
        input_ids_batch = batch['input_ids_batch']
        attention_mask_batch = batch['attention_mask_batch']
        typed_seq_batch = batch['typed_seq_batch']

        embs = self(input_ids_batch, attention_mask_batch)
        has_type_annotation_mask = typed_seq_batch != -100
        embs_with_type_annotation = embs[has_type_annotation_mask]
        type_ids = typed_seq_batch[has_type_annotation_mask]
        
        semi_hard_triplets = self.semi_hard_miner(embs_with_type_annotation, type_ids)
        hard_triplets = self.hard_miner(embs_with_type_annotation, type_ids)
        
        semi_hard_loss = self.loss_fn(embs_with_type_annotation, type_ids, semi_hard_triplets)
        hard_loss = self.loss_fn(embs_with_type_annotation, type_ids, hard_triplets)

        if self.hparams.loss_type == 'SEMI_HARD':
            loss = semi_hard_loss
        else:
            loss = hard_loss

        
        self.train_semi_hard_loss(semi_hard_loss)
        self.train_hard_loss(hard_loss)
        
        self.log('train_semi_hard_loss', self.train_semi_hard_loss, on_step=True, on_epoch=True, prog_bar=True)
        self.log('train_hard_loss', self.train_hard_loss, on_step=True, on_epoch=True, prog_bar=True)

        return loss

    # calculating and storing all embeddings
    def validation_step(self, batch, batch_idx, dataloader_idx=0):
        input_ids_batch = batch['input_ids_batch']
        attention_mask_batch = batch['attention_mask_batch']
        typed_seq_batch = batch['typed_seq_batch']
        typed_seq_batch_wo_t_params = batch['typed_seq_batch_wo_t_params']

        embs = self(input_ids_batch, attention_mask_batch).detach()

        has_type_annotation_mask = typed_seq_batch != -100
        embs_with_type_annotation = embs[has_type_annotation_mask]
        type_ids = typed_seq_batch[has_type_annotation_mask]
        type_ids_wo_t_params = typed_seq_batch_wo_t_params[has_type_annotation_mask]

        if dataloader_idx == 0:
            if batch_idx == 0:
                self.feature_bank = [embs_with_type_annotation.cpu()]
                self.target_bank = [type_ids.cpu()]
                self.target_bank_wo_t_params = [type_ids_wo_t_params.cpu()]
            else:
                self.feature_bank.append(embs_with_type_annotation.cpu())
                self.target_bank.append(type_ids.cpu())
                self.target_bank_wo_t_params.append(type_ids_wo_t_params.cpu())
        else:
            if batch_idx == 0:
                self.query_embeddings = [embs_with_type_annotation.cpu()]
                self.query_labels = [type_ids.cpu()]
                self.query_labels_wo_t_params = [type_ids_wo_t_params.cpu()]
                self.query_positions = [torch.argwhere(has_type_annotation_mask)[:, 1].flatten().cpu()]
            else:
                self.query_embeddings.append(embs_with_type_annotation.cpu())
                self.query_labels.append(type_ids.cpu())
                self.query_labels_wo_t_params.append(type_ids_wo_t_params.cpu())
                self.query_positions.append(torch.argwhere(has_type_annotation_mask)[:, 1].flatten().cpu())
            

            semi_hard_triplets = self.semi_hard_miner(embs_with_type_annotation, type_ids)
            hard_triplets = self.hard_miner(embs_with_type_annotation, type_ids)

            semi_hard_loss = self.loss_fn(embs_with_type_annotation, type_ids, semi_hard_triplets)
            hard_loss = self.loss_fn(embs_with_type_annotation, type_ids, hard_triplets)


            self.log('valid_semi_hard_loss', semi_hard_loss, prog_bar=True)
            self.log('valid_hard_loss', hard_loss, prog_bar=True)

    # calculating metrics based on stored embeddings
    def on_validation_epoch_end(self):
        self.feature_bank = torch.cat(self.feature_bank)
        self.target_bank = torch.cat(self.target_bank)
        self.target_bank_wo_t_params = torch.cat(self.target_bank_wo_t_params)
        self.query_embeddings = torch.cat(self.query_embeddings)
        self.query_labels = torch.cat(self.query_labels)
        self.query_labels_wo_t_params = torch.cat(self.query_labels_wo_t_params)
        
        types_count = get_types_count(self.target_bank)
        ubiquitous_types = UBIQUITOUS_TYPE_IDS
        
        is_ubiquitous_mask = get_is_ubiquitous_mask(self.query_labels, ubiquitous_types)
        ubiquitous_embeddings = self.query_embeddings[is_ubiquitous_mask]
        ubiquitous_labels = self.query_labels[is_ubiquitous_mask]
        
        is_common_mask = get_is_common_mask(self.query_labels, types_count, ubiquitous_types)
        common_embeddings = self.query_embeddings[is_common_mask]
        common_labels = self.query_labels[is_common_mask]
        common_labels_wo_t_params = self.query_labels_wo_t_params[is_common_mask]
        
        is_rare_mask = get_is_rare_mask(self.query_labels, types_count)
        rare_embeddings = self.query_embeddings[is_rare_mask]
        rare_labels = self.query_labels[is_rare_mask]
        rare_labels_wo_t_params = self.query_labels_wo_t_params[is_rare_mask]
        
        
        all_metrics = calculate_metrics(
            query=self.query_embeddings, 
            query_labels=self.query_labels, 
            reference=self.feature_bank,
            reference_labels=self.target_bank,
            title='valid_all'
        )
        all_metrics_wo_t_params = calculate_metrics(
            query=self.query_embeddings, 
            query_labels=self.query_labels_wo_t_params, 
            reference=self.feature_bank,
            reference_labels=self.target_bank_wo_t_params,
            title='valid_all_wo_t_params'
        )
        ubiquitos_metrics = calculate_metrics(
            query=ubiquitous_embeddings,
            query_labels=ubiquitous_labels,
            reference=self.feature_bank,
            reference_labels=self.target_bank,
            title='valid_ubiquitos'
        )
        common_metrics = calculate_metrics(
            query=common_embeddings,
            query_labels=common_labels,
            reference=self.feature_bank,
            reference_labels=self.target_bank,
            title='valid_common'
        )
        common_metrics_wo_t_params = calculate_metrics(
            query=common_embeddings,
            query_labels=common_labels_wo_t_params,
            reference=self.feature_bank,
            reference_labels=self.target_bank_wo_t_params,
            title='valid_common_wo_t_params'
        )
        rare_metrics = calculate_metrics(
            query=rare_embeddings,
            query_labels=rare_labels,
            reference=self.feature_bank,
            reference_labels=self.target_bank,
            title='valid_rare'
        )
        rare_metrics_wo_t_params = calculate_metrics(
            query=rare_embeddings,
            query_labels=rare_labels_wo_t_params,
            reference=self.feature_bank,
            reference_labels=self.target_bank_wo_t_params,
            title='valid_rare_wo_t_params'
        )
        
        combined_metrics = all_metrics
        combined_metrics.update(all_metrics_wo_t_params)
        combined_metrics.update(ubiquitos_metrics)
        combined_metrics.update(common_metrics)
        combined_metrics.update(common_metrics_wo_t_params)
        combined_metrics.update(rare_metrics)
        combined_metrics.update(rare_metrics_wo_t_params)
        
        self.query_positions = torch.cat(self.query_positions)
        pos_mrr_metrics = {}
        for pos in range(MODEL_PARAMS['model_max_length']):
            pos_embeddings = self.query_embeddings[self.query_positions == pos]
            pos_labels = self.query_labels[self.query_positions == pos]
            if len(pos_labels) != 0:
                pos_metrics = calculate_metrics(
                    query = pos_embeddings,
                    query_labels = pos_labels,
                    reference = self.feature_bank,
                    reference_labels=self.target_bank,
                    show_progress=False
                )
                pos_mrr_metrics[pos] = pos_metrics['MRR@10'].item()
        self.logger.experiment.log_curve('validation MRR@10 on pos', x=pos_mrr_metrics.keys(), y=pos_mrr_metrics.values())
            
        
        self.log_dict(
            combined_metrics,
            on_step=False,
            on_epoch=True
        )
        
        del self.feature_bank
        del self.target_bank
        del self.target_bank_wo_t_params
        del self.query_embeddings
        del self.query_labels
        del self.query_labels_wo_t_params
        del self.query_positions

        gc.collect()


    def test_step(self, batch, batch_idx, dataloader_idx=0):
        input_ids_batch = batch['input_ids_batch']
        attention_mask_batch = batch['attention_mask_batch']
        typed_seq_batch = batch['typed_seq_batch']
        typed_seq_batch_wo_t_params = batch['typed_seq_batch_wo_t_params']

        embs = self(input_ids_batch, attention_mask_batch).detach()

        has_type_annotation_mask = typed_seq_batch != -100
        embs_with_type_annotation = embs[has_type_annotation_mask]
        type_ids = typed_seq_batch[has_type_annotation_mask]
        type_ids_wo_t_params = typed_seq_batch_wo_t_params[has_type_annotation_mask]

        if dataloader_idx == 0:
            if batch_idx == 0:
                self.feature_bank = [embs_with_type_annotation.cpu()]
                self.target_bank = [type_ids.cpu()]
                self.target_bank_wo_t_params = [type_ids_wo_t_params.cpu()]
            else:
                self.feature_bank.append(embs_with_type_annotation.cpu())
                self.target_bank.append(type_ids.cpu())
                self.target_bank_wo_t_params.append(type_ids_wo_t_params.cpu())
        else:
            if batch_idx == 0:
                self.query_embeddings = [embs_with_type_annotation.cpu()]
                self.query_labels = [type_ids.cpu()]
                self.query_labels_wo_t_params = [type_ids_wo_t_params.cpu()]
                self.query_positions = [torch.argwhere(has_type_annotation_mask)[:, 1].flatten().cpu()]
            else:
                self.query_embeddings.append(embs_with_type_annotation.cpu())
                self.query_labels.append(type_ids.cpu())
                self.query_labels_wo_t_params.append(type_ids_wo_t_params.cpu())
                self.query_positions.append(torch.argwhere(has_type_annotation_mask)[:, 1].flatten().cpu())
            

            semi_hard_triplets = self.semi_hard_miner(embs_with_type_annotation, type_ids)
            hard_triplets = self.hard_miner(embs_with_type_annotation, type_ids)

            semi_hard_loss = self.loss_fn(embs_with_type_annotation, type_ids, semi_hard_triplets)
            hard_loss = self.loss_fn(embs_with_type_annotation, type_ids, hard_triplets)


            self.log('test_semi_hard_loss', semi_hard_loss, prog_bar=True)
            self.log('test_hard_loss', hard_loss, prog_bar=True)

    def on_test_epoch_end(self):
        self.feature_bank = torch.cat(self.feature_bank)
        self.target_bank = torch.cat(self.target_bank)
        self.target_bank_wo_t_params = torch.cat(self.target_bank_wo_t_params)
        self.query_embeddings = torch.cat(self.query_embeddings)
        self.query_labels = torch.cat(self.query_labels)
        self.query_labels_wo_t_params = torch.cat(self.query_labels_wo_t_params)
        
        types_count = get_types_count(self.target_bank)
        ubiquitous_types = UBIQUITOUS_TYPE_IDS
        
        is_ubiquitous_mask = get_is_ubiquitous_mask(self.query_labels, ubiquitous_types)
        ubiquitous_embeddings = self.query_embeddings[is_ubiquitous_mask]
        ubiquitous_labels = self.query_labels[is_ubiquitous_mask]
        
        is_common_mask = get_is_common_mask(self.query_labels, types_count, ubiquitous_types)
        common_embeddings = self.query_embeddings[is_common_mask]
        common_labels = self.query_labels[is_common_mask]
        common_labels_wo_t_params = self.query_labels_wo_t_params[is_common_mask]
        
        is_rare_mask = get_is_rare_mask(self.query_labels, types_count)
        rare_embeddings = self.query_embeddings[is_rare_mask]
        rare_labels = self.query_labels[is_rare_mask]
        rare_labels_wo_t_params = self.query_labels_wo_t_params[is_rare_mask]
        
        
        all_metrics = calculate_metrics(
            query=self.query_embeddings, 
            query_labels=self.query_labels, 
            reference=self.feature_bank,
            reference_labels=self.target_bank,
            title='test_all'
        )
        all_metrics_wo_t_params = calculate_metrics(
            query=self.query_embeddings, 
            query_labels=self.query_labels_wo_t_params, 
            reference=self.feature_bank,
            reference_labels=self.target_bank_wo_t_params,
            title='test_all_wo_t_params'
        )
        ubiquitos_metrics = calculate_metrics(
            query=ubiquitous_embeddings,
            query_labels=ubiquitous_labels,
            reference=self.feature_bank,
            reference_labels=self.target_bank,
            title='test_ubiquitos'
        )
        common_metrics = calculate_metrics(
            query=common_embeddings,
            query_labels=common_labels,
            reference=self.feature_bank,
            reference_labels=self.target_bank,
            title='test_common'
        )
        common_metrics_wo_t_params = calculate_metrics(
            query=common_embeddings,
            query_labels=common_labels_wo_t_params,
            reference=self.feature_bank,
            reference_labels=self.target_bank_wo_t_params,
            title='test_common_wo_t_params'
        )
        rare_metrics = calculate_metrics(
            query=rare_embeddings,
            query_labels=rare_labels,
            reference=self.feature_bank,
            reference_labels=self.target_bank,
            title='test_rare'
        )
        rare_metrics_wo_t_params = calculate_metrics(
            query=rare_embeddings,
            query_labels=rare_labels_wo_t_params,
            reference=self.feature_bank,
            reference_labels=self.target_bank_wo_t_params,
            title='test_rare_wo_t_params'
        )
        
        combined_metrics = all_metrics
        combined_metrics.update(all_metrics_wo_t_params)
        combined_metrics.update(ubiquitos_metrics)
        combined_metrics.update(common_metrics)
        combined_metrics.update(common_metrics_wo_t_params)
        combined_metrics.update(rare_metrics)
        combined_metrics.update(rare_metrics_wo_t_params)
        
        self.query_positions = torch.cat(self.query_positions)
        pos_mrr_metrics = {}
        for pos in range(MODEL_PARAMS['model_max_length']):
            pos_embeddings = self.query_embeddings[self.query_positions == pos]
            pos_labels = self.query_labels[self.query_positions == pos]
            if len(pos_labels) != 0:
                pos_metrics = calculate_metrics(
                    query = pos_embeddings,
                    query_labels = pos_labels,
                    reference = self.feature_bank,
                    reference_labels=self.target_bank,
                    show_progress=False
                )
                pos_mrr_metrics[pos] = pos_metrics['MRR@10'].item()
        self.logger.experiment.log_curve('test MRR@10 on pos', x=pos_mrr_metrics.keys(), y=pos_mrr_metrics.values())
            
        
        self.log_dict(
            combined_metrics,
            on_step=False,
            on_epoch=True
        )
        
        del self.feature_bank
        del self.target_bank
        del self.target_bank_wo_t_params
        del self.query_embeddings
        del self.query_labels
        del self.query_labels_wo_t_params
        del self.query_positions

        gc.collect()


    def predict_step(self, batch, batch_nb):
        input_ids_batch = batch['input_ids_batch']
        attention_mask_batch = batch['attention_mask_batch']

        embs = self(input_ids_batch, attention_mask_batch)
        return embs


    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate)

In [14]:
if FLAGS['CHECKPOINT_PATH']:
    if MODEL_PARAMS['keep_same']:
        model = Model.load_from_checkpoint(
            FLAGS['CHECKPOINT_PATH'],
            SEED=FLAGS['SEED']
        )
    else:
        model = Model.load_from_checkpoint(
            FLAGS['CHECKPOINT_PATH'], 
            margin=MODEL_PARAMS['margin'],
            loss_type=MODEL_PARAMS['loss_type'],
            learning_rate=MODEL_PARAMS['lr'],
            batch_size=FLAGS['batch_size'],
            seed=FLAGS['SEED']
        )
else:
    model = Model(
        margin=MODEL_PARAMS['margin'],
        loss_type=MODEL_PARAMS['loss_type'],
        token_embeddings_size=len(tokenizer),
        learning_rate=MODEL_PARAMS['lr'],
        batch_size=FLAGS['batch_size'],
        seed=FLAGS['SEED']
    )


# which metric to monitor
METRIC2MONITOR = 'valid_all_MRR@10'
checkpoint_callback = ModelCheckpoint(dirpath=f"{FLAGS['EXPERIMENT_NAME']}_checkpoints", save_top_k=1, monitor=METRIC2MONITOR, mode='max')

# setting up the logger
comet_logger = CometLogger(
    api_key=COMET_API_KEY,
    project_name=FLAGS['PROJECT_NAME'],
    experiment_name=FLAGS['EXPERIMENT_NAME'],
)


# setting up the trainer
trainer = pl.Trainer(
    accelerator="gpu",
    callbacks=
    [
        checkpoint_callback,
    ],
    logger=comet_logger,
    max_epochs=FLAGS['max_epochs'],
    log_every_n_steps=40,
    check_val_every_n_epoch=1,
    enable_progress_bar=True,
    num_sanity_val_steps=0
)

Downloading pytorch_model.bin:   0%|          | 0.00/499M [00:00<?, ?B/s]

In [None]:
if FLAGS['stage'] == 'train':
    trainer.fit(model, dm)
elif FLAGS['stage'] == 'valid':
    dm.setup('fit')
    trainer.validate(model, dm)
elif FLAGS['stage'] == 'test':
    dm.setup('test')
    trainer.test(model, dm)

Downloading and preparing dataset csv/default to /root/.cache/huggingface/datasets/csv/default-1adbca1ba80c7406/0.0.0/433e0ccc46f9880962cc2b12065189766fbb2bee57a221866138fb9203c83519...


Downloading data files:   0%|          | 0/3 [00:00<?, ?it/s]

Extracting data files:   0%|          | 0/3 [00:00<?, ?it/s]

  csv_file_reader = pd.read_csv(file, iterator=True, dtype=dtype, **self.config.read_csv_kwargs)
  csv_file_reader = pd.read_csv(file, iterator=True, dtype=dtype, **self.config.read_csv_kwargs)
  csv_file_reader = pd.read_csv(file, iterator=True, dtype=dtype, **self.config.read_csv_kwargs)


Dataset csv downloaded and prepared to /root/.cache/huggingface/datasets/csv/default-1adbca1ba80c7406/0.0.0/433e0ccc46f9880962cc2b12065189766fbb2bee57a221866138fb9203c83519. Subsequent calls will reuse this data.


  0%|          | 0/3 [00:00<?, ?it/s]

Strings to lists:   0%|          | 0/7 [00:00<?, ?ba/s]

Removing examples with unuequal sequences len:   0%|          | 0/7 [00:00<?, ?ba/s]

Preprocessing normalizedseq2seq:   0%|          | 0/7 [00:00<?, ?ba/s]

Making consistent:   0%|          | 0/6362 [00:00<?, ?ex/s]

Removing quote types:   0%|          | 0/6362 [00:00<?, ?ex/s]

Excluding types:   0%|          | 0/6362 [00:00<?, ?ex/s]

Resolving type alias:   0%|          | 0/6362 [00:00<?, ?ex/s]

Reducing parameters:   0%|          | 0/6362 [00:00<?, ?ex/s]

Removing trivial annotations:   0%|          | 0/7 [00:00<?, ?ba/s]

Removing examples with no annotations:   0%|          | 0/7 [00:00<?, ?ba/s]

Tokenizing for the model:   0%|          | 0/480 [00:00<?, ?ba/s]

Token indices sequence length is longer than the specified maximum sequence length for this model (571 > 512). Running this sequence through the model will result in indexing errors


Removing examples with no annotations:   0%|          | 0/15 [00:00<?, ?ba/s]

Strings to lists:   0%|          | 0/1 [00:00<?, ?ba/s]

Removing examples with unuequal sequences len:   0%|          | 0/1 [00:00<?, ?ba/s]

Preprocessing normalizedseq2seq:   0%|          | 0/1 [00:00<?, ?ba/s]

Making consistent:   0%|          | 0/694 [00:00<?, ?ex/s]

Removing quote types:   0%|          | 0/694 [00:00<?, ?ex/s]

Excluding types:   0%|          | 0/694 [00:00<?, ?ex/s]

Resolving type alias:   0%|          | 0/694 [00:00<?, ?ex/s]

Reducing parameters:   0%|          | 0/694 [00:00<?, ?ex/s]

Removing trivial annotations:   0%|          | 0/1 [00:00<?, ?ba/s]

Removing examples with no annotations:   0%|          | 0/1 [00:00<?, ?ba/s]

Tokenizing for the model:   0%|          | 0/53 [00:00<?, ?ba/s]

Removing examples with no annotations:   0%|          | 0/2 [00:00<?, ?ba/s]

[1;38;5;39mCOMET INFO:[0m Couldn't find a Git repository in '/kaggle/working' nor in any parent directory. Set `COMET_GIT_DIRECTORY` if your Git Repository is elsewhere.
[1;38;5;39mCOMET INFO:[0m Experiment is live on comet.com https://www.comet.com/gfx73/typebert4py/d10fff27c78c4e86aa55e5c1877d9b16



Training: 0it [00:00, ?it/s]

In [None]:
# # saving previous runs weights
# import os
# from pathlib import Path
# from distutils.dir_util import copy_tree
# INPUT_DIR = Path('/kaggle/input/codebertdsl-2')
# OUTPUT_DIR = Path('/kaggle/working')
# for p in os.listdir(INPUT_DIR):
#     if '_checkpoints' in p:
#         source_dir = INPUT_DIR / p
#         dest_dir = OUTPUT_DIR / p
#         if not os.path.exists(dest_dir):
#             os.mkdir(str(dest_dir))
#             copy_tree(str(source_dir), str(dest_dir))