# MultiRD

In [None]:
import os
import sys
import shutil
from pathlib import Path
from typing import Optional

# set GDRIVE_SAVE_DIR="" to disable export of model checkpoint file to Google Drive
# if GDRIVE_SAVE_DIR is specified, the first path component is required to be 'MyDrive'
GDRIVE_SAVE_DIR = "MyDrive/CS-GY 6953 DL/DL final project/checkpoints"
# GDRIVE_SAVE_DIR = ""

class Uploader:
    """Copy saved files into a locally mounted Google Drive folder.

    An instance created without a save path is disabled: every upload call
    returns ``None`` without touching the filesystem.
    """

    def __init__(self, local_gdrive_save_path: Optional[str] = None):
        # Root directory (inside the mounted drive) that receives the copies;
        # a falsy value disables the uploader.
        self.local_gdrive_save_path = local_gdrive_save_path
        # Maps a conservation group name to the last file uploaded for it.
        self.conservation_deck = {}

    def is_enabled(self) -> bool:
        """Return True when a destination folder has been configured."""
        return bool(self.local_gdrive_save_path)

    def replace(self, saved_file: Path, conserve: Optional[str] = None):
        """Register ``saved_file`` as the current file of group ``conserve``.

        The file previously registered for the same group is deleted, so at
        most one file per group is kept. Nothing happens when ``conserve`` is
        empty or ``None``.
        """
        if not conserve:
            return
        previous = self.conservation_deck.get(conserve)
        # Uploading to the same destination twice must not delete the file
        # that was just written.
        if previous is not None and previous != saved_file:
            try:
                os.remove(previous)
            except FileNotFoundError:
                pass
        self.conservation_deck[conserve] = saved_file

    def upload_file(self, src_file: Path, dst_path: str, suppress_error: bool = False, conserve: Optional[str] = None) -> Optional[str]:
        """Copy ``src_file`` to ``dst_path`` below the configured save folder.

        Args:
            src_file: file to copy.
            dst_path: destination path relative to the save folder.
            suppress_error: when True, failures are reported on stderr
                instead of being raised.
            conserve: optional conservation group; the previous upload of the
                same group is removed after a successful copy.

        Returns:
            The destination path as a string, or ``None`` when the uploader
            is disabled or a failure was suppressed.
        """
        if not self.local_gdrive_save_path:
            return None
        dst_file = Path(self.local_gdrive_save_path) / dst_path
        try:
            dst_file.parent.mkdir(exist_ok=True, parents=True)
            shutil.copyfile(src_file, dst_file)
            self.replace(dst_file, conserve=conserve)
            return str(dst_file)
        except Exception as e:
            if suppress_error:
                print(f"suppressing save error {type(e)} {e} on file {src_file} -> {dst_file}", file=sys.stderr)
                return None
            raise

    def upload_checkpoint(self, checkpoint_file: Path, infix: str) -> Optional[str]:
        """Upload a checkpoint under the name ``<stem>-<infix><suffix>``."""
        if not self.local_gdrive_save_path:
            return None
        filename = f"{checkpoint_file.stem}-{infix}{checkpoint_file.suffix}"
        return self.upload_file(checkpoint_file, filename)

    @staticmethod
    def prepare_mount() -> 'Uploader':
        """Mount Google Drive (Colab only) and return a matching uploader.

        A disabled uploader is returned when ``GDRIVE_SAVE_DIR`` is empty,
        when ``google.colab`` cannot be imported, or when mounting fails.
        """
        save_path_root = "/content/gdrive"
        if GDRIVE_SAVE_DIR:
            local_save_root = str(os.path.join(save_path_root, GDRIVE_SAVE_DIR))
            try:
                # noinspection PyUnresolvedReferences
                from google.colab import drive
                drive.mount(save_path_root)
                return Uploader(local_save_root)
            except Exception as e:
                if isinstance(e, ImportError):
                    print("(not saving because not in colab environment)")
                else:
                    print("not saving to gdrive due to", type(e).__name__, e)
        return Uploader()  # not enabled

# Module-wide uploader; it is disabled when GDRIVE_SAVE_DIR is empty or when
# google.colab cannot be imported / the drive cannot be mounted.
UPLOADER = Uploader.prepare_mount()

In [None]:
# unpack data
!(test -d data || (gdown "1zjrLaaKR9Pf-DUmjkoptRyG4SBasgked" && unzip -q "english-rd-data.zip"))
!ls data
DATA_PATH = './data'

data.py

In [None]:
import torch.cuda
import torch.utils.data
import numpy as np
from typing import Callable
from typing import TextIO
from typing import Any
from typing import NamedTuple
from typing import Sequence
from json import load as json_load


# Device used for every tensor created in this notebook. Always a torch.device
# (previously a plain "cpu" string on machines without CUDA).
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

class MyDataset(torch.utils.data.Dataset):
    """Map-style dataset that serves items from an in-memory sequence."""

    def __init__(self, instances):
        # Kept by reference; no copy is made.
        self.instances = instances

    def __getitem__(self, index):
        """Return the instance stored at ``index``."""
        return self.instances[index]

    def __len__(self):
        """Return the number of stored instances."""
        return len(self.instances)
 
def data2index(data_x, word2index, sememe2index, lexname2index, rootaffix2index, rootaffix_freq, frequency):
    """Translate raw dictionary entries into index-based instances.

    Each element of ``data_x`` is expected to look like::

        {
            "word": "restlessly",
            "lexnames": ["adv.all"],
            "root_affix": ["ly"],
            "sememes": ["rash"],
            "definitions": "in a restless manner unquietly"
        }

    Root/affix labels whose count in ``rootaffix_freq`` is below
    ``frequency`` are dropped. A definition word that is unknown, or that
    equals the target word itself, is replaced by the ``<OOV>`` index.
    Entries whose definition is empty are skipped.

    Returns:
        A list of dicts with keys ``word``, ``lexnames``, ``root_affix``,
        ``sememes`` and ``definition_words``.
    """
    indexed = []
    for entry in data_x:
        sememes = [sememe2index[name] for name in entry['sememes']]
        lexnames = [lexname2index[name] for name in entry['lexnames']]
        root_affixes = [
            rootaffix2index[name]
            for name in entry['root_affix']
            if rootaffix_freq[name] >= frequency
        ]
        tokens = entry['definitions'].strip().split()
        if not tokens:
            # some definitions are empty; such entries are ignored
            continue
        definition = [
            word2index[token]
            if token in word2index and token != entry['word']
            else word2index['<OOV>']
            for token in tokens
        ]
        indexed.append({
            'word': word2index[entry['word']],
            'lexnames': lexnames,
            'root_affix': root_affixes,
            'sememes': sememes,
            'definition_words': definition,
        })
    return indexed


def readlines(f: TextIO):
    """Return all remaining lines of ``f`` as a list (line endings kept)."""
    return list(f)


def load_data_file(filename: str, transform: Callable[[TextIO], Any]):
    """Open ``filename`` inside ``DATA_PATH`` and return its parsed content.

    ``transform`` receives the open text file and its result is returned;
    when ``transform`` is ``None`` the raw file text is returned instead.
    """
    path = os.path.join(DATA_PATH, filename)
    with open(path) as ifile:
        return ifile.read() if transform is None else transform(ifile)

def load_data_json(filename: str):
    """Load ``filename`` from the data directory and parse it as JSON."""
    parsed = load_data_file(filename, transform=json_load)
    return parsed

def load_data_lines(filename: str):
    """Load ``filename`` from the data directory as a list of lines."""
    lines = load_data_file(filename, transform=readlines)
    return lines


class WordMappings(NamedTuple):
    """Vocabulary lookup tables together with the embedding matrix."""

    word2index: dict[str, int]
    index2word: list[str]
    word2vec: np.ndarray

    def describe(self) -> dict[str, Any]:
        """Summarize the sizes of the mappings and the embedding matrix."""
        summary: dict[str, Any] = {}
        summary["word2index"] = len(self.word2index)
        summary["index2word"] = len(self.index2word)
        summary["word2vec"] = f"{self.word2vec.shape} {self.word2vec.dtype}"
        return summary

    def expand(self, phrase_indexes: Sequence[int]) -> str:
        """Render word indexes as text.

        Output stops at the first ``<PAD>`` token and every ``<OOV>`` token
        before it is shown as ``#``.
        """
        # Resolve every index first so that an invalid index is reported
        # even when it appears after the padding starts.
        tokens = [self.index2word[word_i] for word_i in phrase_indexes]
        shown = []
        for token in tokens:
            if token == '<PAD>':
                break
            shown.append('#' if token == '<OOV>' else token)
        return " ".join(shown)
            


class CoreMappings(NamedTuple):
    """Index-to-label lists for sememes, lexnames and root/affix labels."""

    index2sememe: list[str]
    index2lexname: list[str]
    index2rootaffix: list[str]

    def describe(self) -> dict[str, int]:
        """Return the number of labels held by each mapping."""
        return {name: len(labels) for name, labels in self._asdict().items()}


class Label(NamedTuple):
    """Label-space sizes as computed by ``load_data``."""

    # number of target words plus 2 (for <PAD> and <OOV>)
    size: int
    # number of lexnames read from lexname_all.txt
    lexname_size: int
    # number of root/affix labels kept after the frequency filter
    rootaffix_size: int
    # number of sememes read from sememes_all.txt plus 1
    sememe_size: int


class Indexes(NamedTuple):
    """Indexed instances of every data split, as produced by ``data2index``."""

    # NOTE(review): load_data fills these fields with the output of
    # data2index, which is a list of dicts, so the list[int] annotations
    # look inaccurate — confirm before relying on them.
    data_train_idx: list[int]
    data_dev_idx: list[int]
    data_test_500_seen_idx: list[int]
    data_test_500_unseen_idx: list[int]
    data_defi_c_idx: list[int]
    data_desc_c_idx: list[int]


class LoadedData(NamedTuple):
    """Everything returned by ``load_data``, grouped by purpose."""

    # vocabulary lookups and embedding matrix
    word_mappings: WordMappings
    # index-to-label lists for sememes, lexnames and root/affix labels
    core_mappings: CoreMappings
    # sizes of the label spaces
    label: Label
    # indexed instances of every data split
    indexes: Indexes


def _index_in_order(items):
    """Give each item the next free index; return (item2index, index2item)."""
    item2index = dict()
    index2item = list()
    for item in items:
        item2index[item] = len(item2index)
        index2item.append(item)
    return item2index, index2item


def load_data(frequency: int) -> LoadedData:
    """Load all data files and convert every split to index form.

    Target words receive the lowest word indexes (directly after ``<PAD>``
    and ``<OOV>``), followed by the remaining embedding vocabulary.
    ``index2word`` is built in lockstep with ``word2index`` so that
    ``index2word[word2index[w]] == w`` for every word.

    Args:
        frequency: minimum count (from ``root_affix_freq.txt``) a root/affix
            label needs in order to be kept.

    Returns:
        The word mappings, label mappings, label sizes and indexed splits.
    """
    print('Loading dataset...')
    data_train = load_data_json("data_train.json")
    data_dev = load_data_json("data_dev.json")
    data_test_500_rand1_seen = load_data_json("data_test_500_rand1_seen.json")
    data_test_500_rand1_unseen = load_data_json("data_test_500_rand1_unseen.json") #data_test_500_others
    data_defi_c = load_data_json("data_defi_c.json")
    data_desc_c = load_data_json("data_desc_c.json")
    target_words = [line.strip() for line in load_data_lines("target_words.txt")]
    label_size = len(target_words)+2
    print('target_words (include <PAD><OOV>): ', label_size)
    lexname_all = [line.strip() for line in load_data_lines("lexname_all.txt")]
    label_lexname_size = len(lexname_all)
    print('label_lexname_size: ', label_lexname_size)
    rootaffix_freq = {}
    for line in load_data_lines("root_affix_freq.txt"):
        fields = line.strip().split()
        rootaffix_freq[fields[0]] = int(fields[1])
    rootaffix_all = [line.strip() for line in load_data_lines("rootaffix_all.txt")]
    sememes_all = [line.strip() for line in load_data_lines("sememes_all.txt")]
    label_sememe_size = len(sememes_all)+1
    print('label_sememe_size: ', label_sememe_size)
    vec_inuse = load_data_json("vec_inuse.json")
    vocab = list(vec_inuse)
    vocab_size = len(vocab)+2
    print('vocab (embeddings in use)(include <PAD><OOV>): ', vocab_size)

    word2index: dict[str, int] = {'<PAD>': 0, '<OOV>': 1}
    index2word: list[str] = ['<PAD>', '<OOV>']
    word2vec = np.zeros((vocab_size, len(list(vec_inuse.values())[0])), dtype=np.float32)
    # A set keeps the membership test O(1) instead of scanning the list.
    target_word_set = set(target_words)
    other_words = [wd for wd in vocab if wd not in target_word_set]
    for wd in target_words + other_words:
        if wd in word2index:
            # A repeated word would otherwise be re-assigned an index that
            # the next word then receives as well.
            continue
        vector = vec_inuse[wd]
        index = len(index2word)
        word2index[wd] = index
        index2word.append(wd)
        word2vec[index, :] = vector

    sememe2index, index2sememe = _index_in_order(sememes_all)
    lexname2index, index2lexname = _index_in_order(lexname_all)
    rootaffix2index, index2rootaffix = _index_in_order(
        ra for ra in rootaffix_all if rootaffix_freq[ra] >= frequency
    )
    label_rootaffix_size = len(index2rootaffix)
    print('label_rootaffix_size: ', label_rootaffix_size)

    def to_index(data):
        # Every split is converted with the same lookup tables.
        return data2index(data, word2index, sememe2index, lexname2index, rootaffix2index, rootaffix_freq, frequency)

    data_train_idx = to_index(data_train)
    print('data_train size: %d'%len(data_train_idx))
    data_dev_idx = to_index(data_dev)
    print('data_dev size: %d'%len(data_dev_idx))
    data_test_500_seen_idx = to_index(data_test_500_rand1_seen)
    print('data_test_seen size: %d'%len(data_test_500_seen_idx))
    data_test_500_unseen_idx = to_index(data_test_500_rand1_unseen)
    print('data_test_unseen size: %d'%len(data_test_500_unseen_idx))
    data_defi_c_idx = to_index(data_defi_c)
    data_desc_c_idx = to_index(data_desc_c)
    print('data_desc size: %d'%len(data_desc_c_idx))
    return LoadedData(
        word_mappings=WordMappings(word2index, index2word, word2vec),
        core_mappings=CoreMappings(index2sememe, index2lexname, index2rootaffix),
        label=Label(label_size, label_lexname_size, label_rootaffix_size, label_sememe_size),
        indexes=Indexes(data_train_idx, data_dev_idx, data_test_500_seen_idx, data_test_500_unseen_idx, data_defi_c_idx, data_desc_c_idx),
    )



    
def build_sentence_numpy(sentences):
    """Stack index sequences into one int64 matrix, right-padded with 0.

    The matrix width is the length of the longest sentence.
    """
    longest = max(len(sentence) for sentence in sentences)
    padded = np.zeros((len(sentences), longest), dtype=np.int64)
    for row, sentence in enumerate(sentences):
        padded[row, :len(sentence)] = np.array(sentence)
    return padded
    

def label_multihot(labels, num):
    """Encode each row of label indexes as a float32 multi-hot vector.

    Reading of a row stops at the first index that is ``>= num``; later
    entries of that row are ignored.
    """
    multihot = np.zeros((len(labels), num), dtype=np.float32)
    for row, row_labels in enumerate(labels):
        for label in row_labels:
            if label >= num:
                break
            multihot[row, label] = 1
    return multihot
    
def my_collate_fn(batch):
    """Collate instances into (target word ids, padded definition ids).

    Both results are int64 tensors created on the module-level ``device``.
    """
    target_ids = np.array([item['word'] for item in batch])
    padded_definitions = build_sentence_numpy([item['definition_words'] for item in batch])
    return (
        torch.tensor(target_ids, dtype=torch.int64, device=device),
        torch.tensor(padded_definitions, dtype=torch.int64, device=device),
    )
    
def word2feature(dataset, word_num, feature_num, feature_name):
    """Build a (word_num, max_feature_count) table of feature ids per word.

    Unused cells hold ``feature_num``, which acts as the padding value.
    Only the first instance of a word that fills column 0 is used; later
    instances of the same word (other definitions) are skipped.
    """
    width = max(len(instance[feature_name]) for instance in dataset)
    table = np.full((word_num, width), feature_num, dtype=np.int64)
    for instance in dataset:
        word = instance['word']
        if table[word, 0] == feature_num:
            feature = instance[feature_name]
            table[word, :len(feature)] = np.array(feature)
    return torch.tensor(table, dtype=torch.int64, device=device)
    
def mask_noFeature(label_size, wd2fea, feature_num):
    """Flag the words that have no feature at all.

    Args:
        label_size: number of leading rows of ``wd2fea`` to inspect.
        wd2fea: 2-D integer tensor of feature ids per word, padded with
            ``feature_num``.
        feature_num: padding value marking "no feature".

    Returns:
        A float32 tensor of length ``label_size`` on the module-level
        ``device``; 1 where every cell of the row equals ``feature_num``,
        0 otherwise.

    Raises:
        IndexError: if ``wd2fea`` has fewer than ``label_size`` rows.
    """
    if wd2fea.shape[0] < label_size:
        raise IndexError(f"wd2fea has {wd2fea.shape[0]} rows, expected at least {label_size}")
    # One vectorized comparison replaces a per-row loop that copied every
    # row to the host.
    has_no_feature = (wd2fea[:label_size] == feature_num).all(dim=1)
    return has_no_feature.to(dtype=torch.float32, device=device)


model.py

In [None]:
import torch

class BiLSTM(torch.nn.Module):
    """Bidirectional LSTM over padded, variable-length batches."""

    def __init__(self, input_dim, hidden_dim, num_layers=1):
        super().__init__()
        self.lstm = torch.nn.LSTM(input_dim, hidden_dim, num_layers, batch_first=True, bidirectional=True)

    def forward(self, x, x_len):
        """Run the LSTM on ``x`` honoring the true length of every row.

        x: T(bat, len, emb) float32
        x_len: T(bat) int64

        Returns ``h, (ht, ct)`` in the original batch order with
        h: T(bat, len, hid*2), ht and ct: T(num_layers*2, bat, hid).
        """
        # pack_padded_sequence is fed length-sorted rows (longest first);
        # restore_order undoes that permutation afterwards.
        sort_order = torch.sort(-x_len)[1]
        restore_order = torch.sort(sort_order)[1]
        sorted_len = x_len[sort_order].to("cpu")
        packed_in = torch.nn.utils.rnn.pack_padded_sequence(x[sort_order], sorted_len, batch_first=True)
        packed_out, (ht, ct) = self.lstm(packed_in, None)
        h = torch.nn.utils.rnn.pad_packed_sequence(packed_out, batch_first=True)[0]
        return h[restore_order], (ht[:, restore_order, :], ct[:, restore_order, :])
        
class Encoder(torch.nn.Module):
    """BiLSTM definition encoder that scores every target word.

    The base score compares the encoded definition with the embeddings of
    the first ``class_num`` vocabulary entries. Depending on ``mode``,
    sememe ('s'), root/affix ('r') and lexname ('l') scores are added.
    """

    def __init__(self, vocab_size, embed_dim, hidden_dim, layers, class_num, sememe_num, lexname_num, rootaffix_num):
        super().__init__()
        self.vocab_size = vocab_size
        self.embed_dim = embed_dim
        self.hidden_dim = hidden_dim
        self.layers = layers
        self.class_num = class_num
        self.sememe_num = sememe_num
        self.lexname_num = lexname_num
        self.rootaffix_num = rootaffix_num
        self.embedding = torch.nn.Embedding(self.vocab_size, self.embed_dim, padding_idx=0, max_norm=5, sparse=True)
        # The embedding table is not trained.
        self.embedding.weight.requires_grad = False
        self.embedding_dropout = torch.nn.Dropout()
        self.encoder = BiLSTM(self.embed_dim, self.hidden_dim, self.layers)
        self.fc = torch.nn.Linear(self.hidden_dim*2, self.embed_dim)
        self.fc_s = torch.nn.Linear(self.hidden_dim*2, self.sememe_num)
        self.fc_l = torch.nn.Linear(self.hidden_dim*2, self.lexname_num)
        self.fc_r = torch.nn.Linear(self.hidden_dim*2, self.rootaffix_num)
        self.loss = torch.nn.CrossEntropyLoss()
        self.relu = torch.nn.ReLU()

    def forward(self, operation, x=None, w=None, ws=None, wl=None, wr=None, msk_s=None, msk_l=None, msk_r=None, mode=None):
        """Score all target words for a batch of definitions.

        Args:
            operation: 'train' returns ``(loss, score, indices)``; 'test'
                returns ``indices``. Any other value returns ``None``.
            x: T(bat, max_word_num) definition word ids, 0 is padding.
            w: T(bat) target word ids (used by 'train' only).
            ws, wl, wr: per-word multi-hot matrices of sememes, lexnames
                and root/affix labels (each has ``class_num`` rows).
            msk_s, msk_l, msk_r: T(class_num) masks of words lacking the
                respective feature.
            mode: string of flags; 's', 'r' and 'l' enable the matching
                auxiliary score. ``None`` enables none of them.
        """
        # A missing mode means "base score only" instead of a TypeError.
        mode = mode or ''
        # x_embedding: T(bat, max_word_num, embed_dim)
        x_embedding = self.embedding(x)
        x_embedding = self.embedding_dropout(x_embedding)
        # mask: T(bat, max_word_num)
        mask = torch.gt(x, 0).to(torch.int64)
        # x_len: T(bat)
        x_len = torch.sum(mask, dim=1)
        # h: T(bat, max_word_num, hid*2)
        # ht: T(num_layers*2, bat, hid) float32
        h, (ht, _) = self.encoder(x_embedding, x_len)
        # ht: T(bat, hid*2)
        ht = torch.transpose(ht[ht.shape[0] - 2:, :, :], 0, 1).contiguous().view(x_len.shape[0], self.hidden_dim*2)
        # alpha: T(bat, max_word_num, 1)
        alpha = (h.bmm(ht.unsqueeze(2)))
        # mask_3: T(bat, max_word_num, 1)
        mask_3 = mask.to(torch.float32).unsqueeze(2)

        ## word prediction
        # vd: T(bat, embed_dim)
        h_1 = torch.sum(h*alpha, 1)
        vd = self.fc(h_1)
        # A plain slice selects the first class_num rows; it replaces the
        # list-of-range index, which depends on legacy sequence indexing.
        score0 = vd.mm(self.embedding.weight[:self.class_num].t())
        score = score0
        if 's' in mode:
            ## sememe prediction
            # pos_score: T(bat, max_word_num, sememe_num)
            pos_score = self.fc_s(h)
            pos_score = pos_score*mask_3 + (-1e7)*(1-mask_3)
            # sem_score: T(bat, sememe_num)
            sem_score, _ = torch.max(pos_score, dim=1)
            # score: T(bat, class_num) = [bat, sememe_num] .mm [class_num, sememe_num].t()
            score_s = self.relu(sem_score.mm(ws.t()))
            #----------add mean sememe score to those who have no sememes
            # mean_sem_sc: T(bat)
            mean_sem_sc = torch.mean(score_s, 1)
            # msk: T(class_num)
            score_s = score_s + mean_sem_sc.unsqueeze(1).mm(msk_s.unsqueeze(0))
            #----------
            score = score + score_s
        if 'r' in mode:
            ## root-affix prediction
            pos_score_ = self.fc_r(h)
            pos_score_ = pos_score_*mask_3 + (-1e7)*(1-mask_3)
            ra_score, _ = torch.max(pos_score_, dim=1)
            score_r = self.relu(ra_score.mm(wr.t()))
            mean_ra_sc = torch.mean(score_r, 1)
            score_r = score_r + mean_ra_sc.unsqueeze(1).mm(msk_r.unsqueeze(0))
            score = score + score_r
        if 'l' in mode:
            ## lexname prediction
            lex_score = self.fc_l(h_1)
            score_l = self.relu(lex_score.mm(wl.t()))
            mean_lex_sc = torch.mean(score_l, 1)
            score_l = score_l + mean_lex_sc.unsqueeze(1).mm(msk_l.unsqueeze(0))
            score = score + score_l

        # fine-tune depended on the target word shouldn't exist in the definition.
        # mask1 is 1 for definition words that are themselves target words;
        # all other positions collapse to index 0 (the <PAD> column).
        mask1 = torch.lt(x, self.class_num).to(torch.int64)
        # Built on the score's own device so the module does not depend on
        # the global device matching where the model lives.
        mask2 = torch.ones(score.shape, dtype=torch.float32, device=score.device)
        # Zero the columns of all words occurring in each definition in one
        # call instead of looping over the batch rows.
        mask2.scatter_(1, x * mask1, 0.)
        score = score * mask2 + (-1e6)*(1-mask2)

        _, indices = torch.sort(score, descending=True)
        if operation == 'train':
            loss = self.loss(score, w)
            return loss, score, indices
        elif operation == 'test':
            return indices


evaluate.py

In [None]:

def evaluate(ground_truth, prediction):
    """Return top-1 / top-10 / top-100 accuracy in percent.

    ``prediction[i]`` is the ranked candidate list for ``ground_truth[i]``.
    """
    length = len(ground_truth)
    top1 = top10 = top100 = 0.
    for i in range(length):
        truth = ground_truth[i]
        ranked = prediction[i]
        if truth not in ranked[:100]:
            continue
        top100 += 1
        if truth not in ranked[:10]:
            continue
        top10 += 1
        if truth == ranked[0]:
            top1 += 1
    return top1/length*100, top10/length*100, top100/length*100

def evaluate_test(ground_truth, prediction):
    """Return accuracy@1/10/100 (percent) plus median and std of the ranks.

    ``prediction[i]`` is the ranked candidate sequence for
    ``ground_truth[i]``. The rank is the 0-based position of the truth in
    that sequence; a truth that does not occur counts as rank 1000.
    """
    length = len(ground_truth)
    accu_1 = 0.
    accu_10 = 0.
    accu_100 = 0.
    pred_rank = []
    for i in range(length):
        try:
            # list() lets array-like rows (which have no .index method) be
            # ranked too, instead of silently falling back to 1000.
            rank = list(prediction[i]).index(ground_truth[i])
        except ValueError:
            rank = None
        pred_rank.append(1000 if rank is None else rank)
        if rank is None:
            continue
        if rank < 100:
            accu_100 += 1
        if rank < 10:
            accu_10 += 1
        if rank == 0:
            accu_1 += 1
    return accu_1/length*100, accu_10/length*100, accu_100/length*100, np.median(pred_rank), np.sqrt(np.var(pred_rank))

# '''
# def evaluate_MAP(ground_truth, prediction):
#     index = 1
#     correct = 0
#     point = 0
#     for predicted_POS in prediction:
#         if predicted_POS in ground_truth:
#             correct += 1
#             point += (correct / index)
#         index += 1
#     point /= len(ground_truth)
#     return point*100.
# 
# import numpy as np    
# def evaluate1(ground_truth, prediction):
#     length = len(ground_truth)
#     ref = np.array(ground_truth)[:, np.newaxis]
#     _, c = np.where(np.array(prediction)==ref)
#     accu_1 = np.sum(c==0)
#     accu_10 = np.sum(c<10)
#     accu_100 = np.sum(c<100)
#     return accu_1/length*100, accu_10/length*100, accu_100/length*100
# '''

Something else?

In [None]:
# TBD

main.py

In [None]:
from typing import NamedTuple
from torch import Tensor

class Counts(NamedTuple):
    """Number of instances in each dataset split."""

    # Loaders.from_split fills these with len() of the matching dataset.
    train: int
    valid: int
    test: int

class Split(NamedTuple):
    """Train / validation / test datasets built from the loaded data."""

    train_dataset: MyDataset
    valid_dataset: MyDataset
    test_dataset: MyDataset

    @staticmethod
    def from_loaded(loaded_data: LoadedData):
        """Assemble the three datasets from the indexed splits.

        Training uses the train and defi_c instances, validation the dev
        instances, and testing the seen, unseen and desc_c instances.
        """
        idx = loaded_data.indexes
        test_instances = idx.data_test_500_seen_idx + idx.data_test_500_unseen_idx + idx.data_desc_c_idx
        train_instances = idx.data_train_idx + idx.data_defi_c_idx
        return Split(
            train_dataset=MyDataset(train_instances),
            valid_dataset=MyDataset(idx.data_dev_idx),
            test_dataset=MyDataset(test_instances),
        )


class Loaders(NamedTuple):
    """Data loaders for every split together with the split sizes."""

    train_dataloader: torch.utils.data.DataLoader
    valid_dataloader: torch.utils.data.DataLoader
    test_dataloader: torch.utils.data.DataLoader
    counts: Counts

    @staticmethod
    def from_split(datasets: Split, batch_size: int) -> 'Loaders':
        """Create the loaders; train and valid are shuffled, test is not."""

        def make_loader(dataset, shuffle):
            return torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, collate_fn=my_collate_fn)

        sizes = Counts(
            train=len(datasets.train_dataset),
            valid=len(datasets.valid_dataset),
            test=len(datasets.test_dataset),
        )
        return Loaders(
            train_dataloader=make_loader(datasets.train_dataset, True),
            valid_dataloader=make_loader(datasets.valid_dataset, True),
            test_dataloader=make_loader(datasets.test_dataset, False),
            counts=sizes,
        )

class Preprocessed(NamedTuple):
    """Per-word feature matrices and no-feature masks used by the model."""

    wd_sems: Tensor
    wd_lex: Tensor
    wd_ra: Tensor
    mask_s: Tensor
    mask_l: Tensor
    mask_r: Tensor

    def to(self, pt_device: str):
        """Return a new instance with every tensor moved to ``pt_device``."""
        moved = {name: tensor.to(pt_device) for name, tensor in self._asdict().items()}
        return Preprocessed(**moved)

    def describe(self) -> dict[str, torch.Size]:
        """Return the shape of every tensor keyed by field name."""
        return {name: tensor.shape for name, tensor in self._asdict().items()}


def prepare_data(loaded_data: LoadedData) -> Preprocessed:
    """Derive the per-word feature matrices and no-feature masks.

    Features are collected from the train, dev, seen-test, unseen-test and
    defi_c instances. Only the first ``label.size`` word indexes (the
    target words) receive feature rows.
    """
    label_size = loaded_data.label.size
    indexes = loaded_data.indexes
    core = loaded_data.core_mappings

    sp = Split.from_loaded(loaded_data)
    print('Train dataset: ', len(sp.train_dataset))
    print('Valid dataset: ', len(sp.valid_dataset))
    print('Test dataset: ', len(sp.test_dataset))

    data_all_idx = (
        indexes.data_train_idx
        + indexes.data_dev_idx
        + indexes.data_test_500_seen_idx
        + indexes.data_test_500_unseen_idx
        + indexes.data_defi_c_idx
    )

    def featurize(index2feature, feature_name):
        # Returns (multi-hot matrix on device, mask of words with no feature).
        feature_num = len(index2feature)
        wd2fea = word2feature(data_all_idx, label_size, feature_num, feature_name)
        multihot = label_multihot(wd2fea, feature_num)
        multihot_t = torch.from_numpy(np.array(multihot)).to(device)
        return multihot_t, mask_noFeature(label_size, wd2fea, feature_num)

    wd_sems, mask_s = featurize(core.index2sememe, 'sememes')
    wd_lex, mask_l = featurize(core.index2lexname, 'lexnames')
    wd_ra, mask_r = featurize(core.index2rootaffix, 'root_affix')
    return Preprocessed(
        wd_sems=wd_sems,
        wd_lex=wd_lex,
        wd_ra=wd_ra,
        mask_s=mask_s,
        mask_l=mask_l,
        mask_r=mask_r,
    )


In [None]:
def build_encoder(loaded: LoadedData):
    """Create an Encoder sized for ``loaded`` with its embeddings installed.

    The hidden size is fixed to 300 with a single LSTM layer; the embedding
    weights are replaced by the loaded ``word2vec`` matrix.
    """
    words = loaded.word_mappings
    core = loaded.core_mappings
    encoder = Encoder(
        vocab_size=len(words.word2index),
        embed_dim=words.word2vec.shape[1],
        hidden_dim=300,
        layers=1,
        class_num=loaded.label.size,
        sememe_num=len(core.index2sememe),
        lexname_num=len(core.index2lexname),
        rootaffix_num=len(core.index2rootaffix),
    )
    encoder.embedding.weight.data = torch.from_numpy(words.word2vec)
    return encoder

In [None]:
from pathlib import Path
from typing import TypeVar

# Type of the item produced by a loader passed to acquire().
T = TypeVar("T")
# When True, acquire() always calls the loader and never reads or writes the cache.
CACHE_DISABLED = False
# Directory in which acquire() stores its cache files.
CACHE_DIR = "./data/cache"

def acquire(cache_filename: str, loader: Callable[[], T]) -> T:
    """Return the item cached under ``cache_filename`` or build and cache it.

    Args:
        cache_filename: file name inside ``CACHE_DIR``.
        loader: zero-argument callable producing the item on a cache miss.

    Returns:
        The cached item when the cache file exists, otherwise the result of
        ``loader()``, which is then written to the cache. With
        ``CACHE_DISABLED`` set, ``loader()`` is always called and nothing is
        read or written.
    """
    if CACHE_DISABLED:
        return loader()
    cache_file = Path(CACHE_DIR) / cache_filename
    if cache_file.is_file():
        # The cache holds arbitrary pickled Python objects written by this
        # notebook, which newer torch versions refuse to load unless
        # weights_only=False is passed. Only load cache files you created
        # yourself: unpickling can execute arbitrary code.
        try:
            item = torch.load(str(cache_file), map_location=device, weights_only=False)
        except TypeError:
            # Older torch versions do not know the weights_only argument.
            item = torch.load(str(cache_file), map_location=device)
        print("loaded from cache", cache_file)
        return item
    item = loader()
    cache_file.parent.mkdir(exist_ok=True, parents=True)
    torch.save(item, str(cache_file))
    print("cached item to", cache_file.as_posix(), cache_file.stat().st_size, "bytes")
    return item

In [None]:
# Minimum root/affix count passed to load_data as its frequency threshold.
FREQUENCY = 20
# Loaded once and then served from the cache file by acquire().
LOADED_DATA = acquire("loaded_data.pth", lambda: load_data(frequency=FREQUENCY))

In [None]:
# Feature matrices and masks derived from LOADED_DATA (cached by acquire()).
PREPROCESSED_DATA = acquire("preprocessed.pth", lambda: prepare_data(LOADED_DATA))
# Print one "name = shape" line per tensor.
print("\n".join(f"{k} = {v}" for k, v in PREPROCESSED_DATA.describe().items()))
# PREPROCESSED_DATA.to(device)

In [None]:
def check(loaded_data: LoadedData, pp: Preprocessed):
    """Smoke test: push one training batch through a fresh encoder.

    Prints the shape and device of the batch tensors and the shapes of the
    resulting loss and ranked indices.
    """
    loaders = Loaders.from_split(Split.from_loaded(loaded_data), batch_size=128)
    words_t, definition_words_t = next(iter(loaders.train_dataloader))
    for label, tensor in (("words_t", words_t), ("definition_words_t", definition_words_t)):
        print(label, tensor.shape, getattr(tensor, "device", None))
    model = build_encoder(loaded_data)
    model.to(device)
    loss, _, indices = model(
        'train',
        x=definition_words_t,
        w=words_t,
        ws=pp.wd_sems,
        wl=pp.wd_lex,
        wr=pp.wd_ra,
        msk_s=pp.mask_s,
        msk_l=pp.mask_l,
        msk_r=pp.mask_r,
        mode='b',
    )
    print("loss", loss.shape)
    print("indices", indices.shape)

check(LOADED_DATA, PREPROCESSED_DATA)

In [None]:
import torch, os, random, json
from tqdm import tqdm
from datetime import datetime
# NOTE(review): RUN is not used in the visible part of this cell; presumably
# it gates the training run further below — confirm.
RUN = False
# When True, gc is the real garbage-collector module; otherwise a stand-in
# whose collect() does nothing is bound to the same name.
FORCE_GC = False

if FORCE_GC:
    import gc
else:
    class Noop:
        # Same collect() entry point as the gc module, but without effect.
        def collect(self):
            pass
    gc = Noop()


class Trajectory(NamedTuple):
    """Per-epoch loss and accuracy values of one data split."""

    loss: list[float]
    accuracy: list[dict[int, float]]

    @staticmethod
    def create() -> 'Trajectory':
        """Return a trajectory with fresh, empty lists."""
        return Trajectory(loss=[], accuracy=[])


class History(NamedTuple):
    """Training and validation trajectories of one run."""

    train: Trajectory
    valid: Trajectory

    @staticmethod
    def create() -> 'History':
        """Return a history with two empty trajectories."""
        return History(train=Trajectory.create(), valid=Trajectory.create())

    def to_dict(self):
        """Convert to plain nested dicts (used when writing checkpoints)."""
        return {name: getattr(self, name)._asdict() for name in ("train", "valid")}

    @staticmethod
    def from_dict(d: dict[str, Any]) -> 'History':
        """Rebuild a history from the structure produced by ``to_dict``."""
        trajectories = {name: Trajectory(**fields) for name, fields in d.items()}
        return History(**trajectories)
        


class Restoration(NamedTuple):
    """Content of a checkpoint file as returned by ``Checkpointer.restore``."""

    # model parameters as stored under the checkpoint's "state_dict" key
    state_dict: dict[str, Any]
    # epoch number stored in the checkpoint
    epoch: int
    # training history rebuilt with History.from_dict
    history: History
    


class Checkpointer:
    """Write model checkpoints and prediction lists for one training run."""

    def __init__(self,
                 checkpoint_dir: Path,
                 history: History,
                 index2word: np.ndarray,
                 mode: str,
                 uploader: Uploader):
        # Directory that receives all files of this run.
        self.checkpoint_dir = checkpoint_dir
        # Shared history object; its current content is stored in every checkpoint.
        self.history = history
        # Mode string, used in the names of the result files.
        self.mode = mode
        # Array of words addressed by word index.
        self.index2word = index2word
        self.uploader = uploader

    def save_results(self, label_list: list[int], pred_list: list[list[int]], epoch: int):
        """Write labels and predictions of ``epoch`` as word lists in JSON.

        Returns:
            The two written files (labels first, predictions second).
        """
        saved_files = []
        # The directory may not exist yet when results are saved before
        # the first checkpoint is written.
        self.checkpoint_dir.mkdir(exist_ok=True, parents=True)
        for infix, content in {
            "label": (self.index2word[label_list]).tolist(),
            "pred": (self.index2word[np.array(pred_list)]).tolist()
        }.items():
            list_file = self.checkpoint_dir / f"checkpoint-epoch{epoch:02d}-{self.mode}_{infix}_list.json"
            with open(list_file, "w") as ofile:
                json.dump(content, ofile, indent=2)
            saved_files.append(list_file)
        return saved_files

    @staticmethod
    def restore(checkpoint_file: Path, pt_device: Optional[str] = None) -> 'Restoration':
        """Load a checkpoint written by ``checkpoint``.

        Keys of the stored dict that are not fields of ``Restoration`` are
        ignored.
        """
        checkpoint = torch.load(str(checkpoint_file), map_location=pt_device)
        checkpoint["history"] = History.from_dict(checkpoint["history"])
        fields = {k: v for k, v in checkpoint.items() if k in Restoration._fields}
        return Restoration(**fields)

    def checkpoint(self, model: torch.nn.Module, epoch: int) -> Path:
        """Save model state, epoch and history; upload the file if enabled.

        Upload errors are suppressed. The previously uploaded checkpoint of
        this run is replaced through the "checkpoint" conservation group.

        Returns:
            The local checkpoint file.
        """
        checkpoint_file = self.checkpoint_dir / f"model-epoch{epoch:02d}.pt"
        checkpoint_file.parent.mkdir(exist_ok=True, parents=True)
        checkpoint = {
            "state_dict": model.state_dict(),
            "epoch": epoch,
            "history": self.history.to_dict(),
        }
        torch.save(checkpoint, str(checkpoint_file))
        self.uploader.upload_file(checkpoint_file, checkpoint_file.name, suppress_error=True, conserve="checkpoint")
        return checkpoint_file

def timestamp() -> str:
    """Return the current local time formatted as ``YYYYMMDD-HHMM``."""
    return f"{datetime.now():%Y%m%d-%H%M}"


def main(loaded: LoadedData, ld: Loaders, pp: Preprocessed, epoch_num: int, quiet: bool, MODE: str = 'b'):
    """Train the encoder for `epoch_num` epochs, checkpointing on validation improvement.

    Each epoch runs one pass over `ld.train_dataloader` (Adam, lr 0.001, gradient
    norm clipped to 1) followed by one pass over `ld.valid_dataloader` under
    `torch.no_grad()`. Mean loss and the three values returned by `evaluate`
    (stored under keys 1, 10 and 100) are appended to a fresh `History` for both
    splits. Whenever the second of those validation values exceeds the best seen
    so far, the model is checkpointed.

    Args:
        loaded: loaded data; provides the word mappings and is passed to `build_encoder`.
        ld: data loaders and sample counts for the train/valid splits.
        pp: preprocessed tensors/masks forwarded to every model call.
        epoch_num: number of epochs to run.
        quiet: when true, tqdm progress bars are disabled.
        MODE: forwarded to the model as `mode` and to the `Checkpointer`.

    Returns:
        The path of the most recently written checkpoint, or None if no epoch
        improved on the initial best value of 0.
    """
    index2word = np.array(loaded.word_mappings.index2word)
    model = build_encoder(loaded)
    # `device` is a module-level name defined outside this function.
    model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001) # Adam
    best_valid_accu = 0
    history = History.create()
    # Checkpoints go to a per-run directory named after the start time.
    checkpointer = Checkpointer(Path("./checkpoints") / timestamp(), history, index2word, mode=MODE, uploader=UPLOADER)
    checkpoint_file = None
    # NOTE(review): DEF_UPDATE is never read in this function.
    DEF_UPDATE = True
    for epoch in range(epoch_num):
        print('epoch: ', epoch)
        model.train()
        train_loss = 0
        label_list = list()
        pred_list = list()

        for words_t, definition_words_t in tqdm(ld.train_dataloader, disable=quiet):
            optimizer.zero_grad()
            loss, _, indices = model('train', x=definition_words_t, w=words_t, ws=pp.wd_sems, wl=pp.wd_lex, wr=pp.wd_ra, msk_s=pp.mask_s, msk_l=pp.mask_l, msk_r=pp.mask_r, mode=MODE)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1)
            optimizer.step()
            # Keep only the first 100 columns of `indices` per sample for scoring.
            predicted = indices[:, :100].detach().cpu().numpy().tolist()
            train_loss += loss.item()
            label_list.extend(words_t.detach().cpu().numpy())
            pred_list.extend(predicted)
        train_accu_1, train_accu_10, train_accu_100 = evaluate(label_list, pred_list)
        # Release the per-epoch lists before the validation pass allocates its own.
        del label_list
        del pred_list
        gc.collect()
        # Sum of per-batch losses divided by ld.counts.train
        # (presumably the number of training samples; verify against Loaders).
        train_loss_mean = train_loss/ld.counts.train
        history.train.loss.append(train_loss_mean)
        history.train.accuracy.append({1: train_accu_1, 10: train_accu_10, 100: train_accu_100})
        print('train_loss: ', train_loss_mean)
        print('train_accu(1/10/100): %.2f %.2F %.2f'%(train_accu_1, train_accu_10, train_accu_100))
        model.eval()
        with torch.no_grad():
            valid_loss = 0
            label_list = []
            pred_list = []
            for words_t, definition_words_t in tqdm(ld.valid_dataloader, disable=quiet):
                # The model is still called with 'train' here, the call form that
                # returns a loss alongside the indices.
                loss, _, indices = model('train', x=definition_words_t, w=words_t, ws=pp.wd_sems, wl=pp.wd_lex, wr=pp.wd_ra, msk_s=pp.mask_s, msk_l=pp.mask_l, msk_r=pp.mask_r, mode=MODE)
                predicted = indices[:, :100].detach().cpu().numpy().tolist()
                valid_loss += loss.item()
                label_list.extend(words_t.detach().cpu().numpy())
                pred_list.extend(predicted)
            valid_accu_1, valid_accu_10, valid_accu_100 = evaluate(label_list, pred_list)
            valid_loss_mean = valid_loss/ld.counts.valid
            history.valid.loss.append(valid_loss_mean)
            history.valid.accuracy.append({1: valid_accu_1, 10: valid_accu_10, 100: valid_accu_100})
            print('valid_loss: ', valid_loss_mean)
            print('valid_accu(1/10/100): %.2f %.2F %.2f'%(valid_accu_1, valid_accu_10, valid_accu_100))

            # Model selection uses the second validation metric (key 10);
            # only a strict improvement triggers a new checkpoint.
            if valid_accu_10 > best_valid_accu:
                best_valid_accu = valid_accu_10
                print('-----best_valid_accu-----')
                checkpoint_file = checkpointer.checkpoint(model, epoch)
            del label_list
            del pred_list
            gc.collect()
    return checkpoint_file
            
def setup_seed(seed):
    """Seed PyTorch, NumPy and Python's `random` with `seed` and make cuDNN deterministic."""
    seeders = (
        torch.manual_seed,
        torch.cuda.manual_seed_all,
        np.random.seed,
        random.seed,
    )
    for apply_seed in seeders:
        apply_seed(seed)
    torch.backends.cudnn.deterministic = True


def example():
    """Run one seeded 25-epoch training session on the notebook's global data.

    Uses seed 543624, batch size 128 and mode 'b'; reads the module-level
    LOADED_DATA and PREPROCESSED_DATA.

    Returns:
        Whatever `main` returns: the path of the last checkpoint written, or None.
    """
    setup_seed(543624)
    split = Split.from_loaded(LOADED_DATA)
    batch_size = 128
    data_loaders = Loaders.from_split(split, batch_size=batch_size)
    print(f'DataLoaders prepared. Batch_size {batch_size}')
    return main(
        loaded=LOADED_DATA,
        ld=data_loaders,
        pp=PREPROCESSED_DATA,
        epoch_num=25,
        quiet=False,
        MODE='b',
    )

# Path of the checkpoint produced by this session; stays None when training is skipped,
# so the cells below can detect that nothing was trained.
CHECKPOINT_FILE = None
# RUN is expected to be defined outside this cell; training only starts when it is truthy.
if RUN:
    CHECKPOINT_FILE = example()


In [None]:
%matplotlib inline
import numpy as np
from matplotlib.figure import Figure
from matplotlib import pyplot as plt
from matplotlib.axes import Axes
from typing import Tuple

class CurveData(NamedTuple):
    """One train/validation pair of per-epoch series plus the labels used to plot it."""
    
    # Name of the plotted quantity; used as the subplot title and in the legend labels.
    subject: str
    # Text for the y-axis label.
    y_label: str
    # Per-epoch values for the training split.
    train: Sequence[float]
    # Per-epoch values for the validation split.
    valid: Sequence[float]
    # Optional (low, high) y-axis limits; None leaves the limits to matplotlib.
    y_bounds: Optional[Tuple[float, float]] = None
    


def plot_epochs_curves(history: History, title: Optional[str] = None, rank_threshold: int = 10):
    """Plot loss and rank-k accuracy per epoch for train and validation, side by side.

    Args:
        history: training history; `history.train` / `history.valid` each provide a
            `loss` sequence and an `accuracy` sequence of dicts keyed by rank threshold.
        title: optional figure title.
        rank_threshold: key selecting which accuracy value of each epoch is plotted.

    The figure is displayed with `plt.show()`; nothing is returned.
    """
    fig, axes = plt.subplots(1, 2, figsize=(8, 4))
    fig: Figure
    if title:
        fig.suptitle(title)
    loss = CurveData("Loss", "Cross-Entropy Loss", history.train.loss, history.valid.loss)
    train_acc = np.array([acc[rank_threshold] for acc in history.train.accuracy])
    valid_acc = np.array([acc[rank_threshold] for acc in history.valid.accuracy])
    acc = CurveData(f"Rank-{rank_threshold} Accuracy", "Correct (%)", train_acc, valid_acc, (0.0, 100.0))
    for ax, curve in zip(axes, [loss, acc]):
        curve: CurveData
        ax: Axes
        ax.set_title(curve.subject)
        ax.set_xlabel("Epochs")
        ax.set_ylabel(curve.y_label)
        # Each series gets its own x-range: a single shared range sized to the longer
        # series makes matplotlib raise when the two series differ in length.
        for split_name, series in (("Train", curve.train), ("Validation", curve.valid)):
            values = np.asarray(series)
            ax.plot(np.arange(len(values)), values, label=f"{split_name} {curve.subject}")
        ax.legend()
        if curve.y_bounds is not None:
            ax.set_ylim(*curve.y_bounds)
    plt.show()

def show_checkpoint_plots(checkpoint_file: Path):
    """Plot the training curves stored in a checkpoint file.

    Args:
        checkpoint_file: path of a saved checkpoint; may be None when no
            training has been run in this session.

    Prints a message and returns without plotting when the path is unset or the
    file does not exist.
    """
    if not checkpoint_file or not checkpoint_file.exists():
        print("checkpoint file not defined or not found:", checkpoint_file)
        # Stop here: restoring from a missing file would only raise.
        return
    restored = Checkpointer.restore(checkpoint_file, pt_device="cpu")
    plot_epochs_curves(restored.history)

# Uncomment to plot the curves of a checkpoint saved by an earlier run
# instead of the one produced in this session:
# CHECKPOINT_FILE = Path("./checkpoints/20240502-2111/model-epoch05.pt")
show_checkpoint_plots(CHECKPOINT_FILE)

In [None]:
class FullStack(NamedTuple):
    """A restored model together with the data objects needed to run it."""

    model: torch.nn.Module
    loaded_data: LoadedData
    preprocessed: Preprocessed

    @staticmethod
    def instantiate(checkpoint_file: Path, pt_device: str = None) -> 'FullStack':
        """Rebuild the model from `checkpoint_file` and pair it with its data.

        The checkpoint is restored first. The module-level LOADED_DATA and
        PREPROCESSED_DATA are reused when they exist; otherwise each is obtained
        through `acquire` (with "loaded_data.pth" / "preprocessed.pth" and a
        factory callback). The encoder is then built, given the stored weights,
        moved to `pt_device` and put in eval mode.

        Args:
            checkpoint_file: checkpoint to restore.
            pt_device: device used both for loading the checkpoint and for the model.
        """
        restored = Checkpointer.restore(checkpoint_file, pt_device=pt_device)
        try:
            data = LOADED_DATA
        except NameError:
            # The global was never defined in this session.
            data = acquire("loaded_data.pth", lambda: load_data(frequency=FREQUENCY))
        try:
            prepared = PREPROCESSED_DATA
        except NameError:
            prepared = acquire("preprocessed.pth", lambda: prepare_data(data))
        network = build_encoder(data)
        network.load_state_dict(restored.state_dict)
        network.to(pt_device)
        network.eval()
        return FullStack(model=network, loaded_data=data, preprocessed=prepared)
            
        

In [None]:

def evaluate_test_set(checkpoint_file: Path, pt_device: str = None, quiet: bool = False, MODE: str = 'b'):
    """Score a saved checkpoint on the test split and print the metrics.

    Args:
        checkpoint_file: checkpoint to restore; None or a plain string path is
            tolerated when reporting a missing file.
        pt_device: device passed to `FullStack.instantiate`.
        quiet: when true, the tqdm progress bar is disabled.
        MODE: forwarded to the model as `mode`.

    Prints the five values returned by `evaluate_test` (bound here to the names
    accuracy at 1/10/100, median and variance). Prints a message and returns
    early when the checkpoint file does not exist. Returns None.
    """
    try:
        stack = FullStack.instantiate(checkpoint_file, pt_device=pt_device)
    except FileNotFoundError:
        # checkpoint_file is None when nothing was trained, and may be a str when
        # set by hand in the notebook; neither has .as_posix().
        shown = Path(checkpoint_file).as_posix() if checkpoint_file is not None else None
        print("checkpoint file does not exist:", shown)
        return
    split = Split.from_loaded(stack.loaded_data)
    dataloader = Loaders.from_split(split, batch_size=512).test_dataloader
    pp = stack.preprocessed
    model = stack.model
    label_list = []
    pred_list = []
    # Inference only: skip autograd bookkeeping, as the validation loop in main() does.
    with torch.no_grad():
        for words_t, definition_words_t in tqdm(dataloader, disable=quiet):
            indices = model('test', x=definition_words_t, w=words_t, ws=pp.wd_sems, wl=pp.wd_lex, wr=pp.wd_ra, msk_s=pp.mask_s, msk_l=pp.mask_l, msk_r=pp.mask_r, mode=MODE)
            # Keep the first 1000 columns of `indices` per sample for scoring.
            predicted = indices[:, :1000].detach().cpu().numpy().tolist()
            label_list.extend(words_t.detach().cpu().numpy())
            pred_list.extend(predicted)
    test_accu_1, test_accu_10, test_accu_100, median, variance = evaluate_test(label_list, pred_list)
    print('test_accu(1/10/100): %.2f %.2F %.2f %.2f %.2f'%(test_accu_1, test_accu_10, test_accu_100, median, variance))

# Uncomment to evaluate a checkpoint saved by an earlier run:
# CHECKPOINT_FILE = Path("./checkpoints/[...]")
# evaluate_test_set is the checkpoint-level entry point; evaluate_test is the
# metric helper it calls with (label_list, pred_list).
evaluate_test_set(CHECKPOINT_FILE, device)


In [None]:
import tabulate

def user_demo(checkpoint_file: Path, pt_device: str = None, MODE: str = 'b'):
    """Run the restored model on one small test batch and print a result table.

    Takes the first batch (batch size 8) of the test split, prints diagnostic
    shape/dtype information, then tabulates for each sample the top prediction,
    the actual word, the rank of the actual word among the first 1000
    predictions ("1000+" when absent) and the expanded definition.

    Args:
        checkpoint_file: checkpoint to restore; None or a plain string path is
            tolerated when reporting a missing file.
        pt_device: device passed to `FullStack.instantiate`.
        MODE: forwarded to the model as `mode`.

    Prints a message and returns early when the checkpoint file does not exist.
    Returns None.
    """
    try:
        stack = FullStack.instantiate(checkpoint_file, pt_device=pt_device)
    except FileNotFoundError:
        # checkpoint_file is None when nothing was trained, and may be a str when
        # set by hand in the notebook; neither has .as_posix().
        shown = Path(checkpoint_file).as_posix() if checkpoint_file is not None else None
        print("checkpoint file does not exist:", shown)
        return
    loaded_data = stack.loaded_data
    pp = stack.preprocessed
    model = stack.model
    print("word mappings", loaded_data.word_mappings.describe())
    print("core mappings", loaded_data.core_mappings.describe())
    print("label", loaded_data.label)
    split = Split.from_loaded(loaded_data)
    dataloader = Loaders.from_split(split, batch_size=8).test_dataloader
    label_list = []
    pred_list = []
    words_t, definition_words_t = next(iter(dataloader))
    words_t: Tensor
    print(f'words_t: {words_t.shape} {words_t.dtype}')
    print(f'definition_words_t: {definition_words_t.shape}')
    # Inference only: no gradients are needed for the demo.
    with torch.no_grad():
        indices = model('test', x=definition_words_t, w=words_t, ws=pp.wd_sems, wl=pp.wd_lex, wr=pp.wd_ra, msk_s=pp.mask_s, msk_l=pp.mask_l, msk_r=pp.mask_r, mode=MODE)
    predicted: list[list[int]] = indices[:, :1000].detach().cpu().numpy().tolist()
    label_list.extend(words_t.detach().cpu().numpy())

    def _describe(seq):
        a = np.array(seq)
        return f"{a.shape} {a.dtype}"

    print("label list", _describe(label_list))
    pred_list.extend(predicted)
    print("predicted", _describe(predicted))
    print("pred_list", _describe(pred_list))

    index2word = loaded_data.word_mappings.index2word
    table = []
    for index, (word_i, definition_i) in enumerate(zip(words_t, definition_words_t)):
        definition_cat = loaded_data.word_mappings.expand(definition_i)
        ranked = predicted[index]
        # Convert the 0-d tensor once so the rank search compares plain ints.
        target = int(word_i)
        prediction_word = index2word[ranked[0]]
        try:
            rank = ranked.index(target) + 1
        except ValueError:
            # The actual word is not among the first 1000 predictions.
            rank = "1000+"
        table.append([prediction_word, index2word[target], rank, definition_cat])
    print()
    print()
    print(tabulate.tabulate(table, headers=["prediction", "actual", "actual\nrank", "clue"]))
        


# Uncomment to run the demo on a checkpoint saved by an earlier run
# (use a Path, since user_demo calls .as_posix() on it when the file is missing):
# CHECKPOINT_FILE = Path("./checkpoints/[...]")
user_demo(CHECKPOINT_FILE, device)