In [None]:
import argparse
import json
import os
import random
from typing import Any, Dict, List, Union

import pytorch_lightning as pl
import torch
from quaterion_models.encoders import Encoder
from quaterion_models.heads import EncoderHead, GatedHead, SkipConnectionHead
from quaterion_models.types import CollateFnType
from torch.utils.data import Dataset

from quaterion import Quaterion, TrainableModel
from quaterion.loss import TripletLoss
from quaterion.dataset.similarity_data_loader import (
    GroupSimilarityDataLoader,
    SimilarityGroupSample,
)
#import SkipConnectionHead


import torch
import torch.nn as nn

from quaterion.loss import SimilarityLoss, SoftmaxLoss

from sentence_transformers import SentenceTransformer
from sentence_transformers.models import Transformer, Pooling

import pandas as pd


In [None]:
# Raw training records, kept for ad-hoc inspection. The training pipeline below does
# not use `data`; it re-reads the file through JsonDataset.
# JSON interchange text is UTF-8, so don't rely on the platform's default encoding.
with open("./labeled_data_train.json", "r", encoding="utf-8") as f:
    data = json.load(f)

In [None]:
from quaterion.dataset.similarity_data_loader import (
    GroupSimilarityDataLoader,
    SimilarityGroupSample,
)
import numpy as np
class JsonDataset(Dataset):
    """Records from a JSON file, served as Quaterion ``SimilarityGroupSample`` objects.

    The file is expected to hold a JSON array; each element becomes one sample's
    ``obj``. Every record is assigned a group id once, at construction time, so a
    record keeps the same group on every access and in every epoch:

    * with ``group_key`` set, records whose ``record[group_key]`` values are equal
      share a group id (ids are assigned in ``str``-sorted order of the values);
    * otherwise each record gets a pseudo-random id in ``[0, num_groups)`` drawn
      from ``random.Random(seed)``. Such ids carry no semantic signal for the
      similarity loss -- TODO(review): pass the real label field as ``group_key``.

    Args:
        path: path to the JSON file (read as UTF-8).
        group_key: optional record key whose (hashable) value defines the group.
        num_groups: number of random groups used when ``group_key`` is None.
        seed: seed for the random group assignment; None seeds from system entropy.
    """

    def __init__(
        self,
        path: str,
        group_key: Union[str, None] = None,
        num_groups: int = 5,
        seed: Union[int, None] = 0,
    ):
        super().__init__()
        with open(path, "r", encoding="utf-8") as f:
            self.data = json.load(f)

        if group_key is not None:
            labels = [record[group_key] for record in self.data]
            # Sort by str() so label values of mixed types still get a stable order.
            label_to_group = {
                label: group
                for group, label in enumerate(sorted(set(labels), key=str))
            }
            self.groups = [label_to_group[label] for label in labels]
        else:
            # Draw once per record (not per __getitem__ call) so a record's group
            # is stable across reads and epochs.
            rng = random.Random(seed)
            self.groups = [rng.randrange(num_groups) for _ in self.data]

    def __getitem__(self, index: int) -> "SimilarityGroupSample":
        """Return record ``index`` paired with its precomputed group id."""
        return SimilarityGroupSample(obj=self.data[index], group=self.groups[index])

    def __len__(self) -> int:
        """Number of records loaded from the file."""
        return len(self.data)

# Batch size shared by the train and validation loaders.
BATCH_SIZE = 128

# shuffle=True so batch composition varies between epochs; the validation
# loader keeps a fixed order so its metrics are comparable across epochs.
train_dataloader = GroupSimilarityDataLoader(
    JsonDataset('./labeled_data_train.json'), batch_size=BATCH_SIZE, shuffle=True
)
val_dataloader = GroupSimilarityDataLoader(
    JsonDataset('./labeled_data_val.json'), batch_size=BATCH_SIZE
)

In [None]:
# Separate instance of the training dataset, used below to look at single samples.
ds = JsonDataset(path='./labeled_data_train.json')

In [None]:
# Peek at one sample; indexing dispatches to JsonDataset.__getitem__.
ds[1]

In [None]:
class DescriptionEncoder(Encoder):
    """Frozen encoder that embeds each record's ``Title`` and ``Description`` text.

    Chains a sentence-transformers ``Transformer`` module with a ``Pooling`` module
    and returns the pooled ``sentence_embedding``. The weights are not updated
    during training (see ``trainable``).
    """

    def __init__(self, transformer: Transformer, pooling: Pooling):
        super().__init__()
        self.transformer = transformer
        self.pooling = pooling
        # Same module objects as above, chained so forward() is a single call.
        self.encoder = nn.Sequential(self.transformer, self.pooling)

    @property
    def trainable(self) -> bool:
        # False: disable weight updates for this encoder.
        return False

    @property
    def embedding_size(self) -> int:
        # forward() returns the pooling output, so report the pooling's output width.
        # It equals the word-embedding dimension for a single pooling mode, but not
        # when several pooling modes are concatenated.
        return self.pooling.get_sentence_embedding_dimension()

    def forward(self, batch) -> torch.Tensor:
        """Embed a tokenized batch produced by ``collate_descriptions``.

        Returns:
            The ``sentence_embedding`` tensor emitted by the pooling module.
        """
        return self.encoder(batch)["sentence_embedding"]

    def collate_descriptions(self, batch: List[Any]) -> Dict[str, torch.Tensor]:
        """Join each record's title and description, then tokenize the texts.

        Args:
            batch: records with ``'Title'`` and ``'Description'`` string fields.

        Returns:
            The tokenizer output of ``Transformer.tokenize`` (a mapping of tensors).
        """
        descriptions = [record['Title'] + '. ' + record['Description'] for record in batch]
        return self.transformer.tokenize(descriptions)

    def get_collate_fn(self) -> CollateFnType:
        """Collate function Quaterion uses to turn raw records into model input."""
        return self.collate_descriptions

    @staticmethod
    def _transformer_path(path: str) -> str:
        # Subdirectory holding the transformer weights; shared by save() and load().
        return os.path.join(path, 'transformer')

    @staticmethod
    def _pooling_path(path: str) -> str:
        # Subdirectory holding the pooling config; shared by save() and load().
        return os.path.join(path, 'pooling')

    def save(self, output_path: str):
        """Save transformer and pooling into ``output_path/transformer`` and ``output_path/pooling``."""
        transformer_path = self._transformer_path(output_path)
        pooling_path = self._pooling_path(output_path)
        # Pooling.save opens <dir>/config.json for writing without creating <dir>,
        # so make sure both subdirectories exist first.
        os.makedirs(transformer_path, exist_ok=True)
        os.makedirs(pooling_path, exist_ok=True)
        self.transformer.save(transformer_path)
        self.pooling.save(pooling_path)

    @classmethod
    def load(cls, input_path: str) -> Encoder:
        """Rebuild an encoder from a directory written by ``save``."""
        transformer = Transformer.load(cls._transformer_path(input_path))
        pooling = Pooling.load(cls._pooling_path(input_path))
        return cls(transformer=transformer, pooling=pooling)

In [None]:
class Model(TrainableModel):
    """Frozen sentence-transformer encoder + skip-connection head, trained with triplet loss.

    Args:
        lr: learning rate for the Adam optimizer.
        pretrained_name: sentence-transformers checkpoint used to build the encoder.
    """

    def __init__(self, lr: float, pretrained_name: str = "all-MiniLM-L6-v2"):
        # Store config before super().__init__(): TrainableModel's constructor runs
        # the configure_* hooks, and configure_encoders() reads _pretrained_name.
        self._lr = lr
        self._pretrained_name = pretrained_name
        super().__init__()

    def configure_encoders(self) -> Union[Encoder, Dict[str, Encoder]]:
        """Split the pretrained SentenceTransformer into its transformer and pooling modules."""
        pre_trained = SentenceTransformer(self._pretrained_name)
        transformer, pooling = pre_trained[0], pre_trained[1]
        return DescriptionEncoder(transformer, pooling)

    def configure_head(self, input_embedding_size) -> EncoderHead:
        """Trainable head placed on top of the frozen encoder output."""
        return SkipConnectionHead(input_embedding_size)

    def configure_loss(self) -> SimilarityLoss:
        """Triplet loss over the group labels supplied by the dataloader."""
        return TripletLoss()

    def configure_optimizers(self):
        """Adam over the similarity model's parameters at the configured learning rate."""
        return torch.optim.Adam(self.model.parameters(), lr=self._lr)

In [None]:
# Seed Python, NumPy and torch RNGs so head initialisation and loader
# shuffling are reproducible under Restart & Run All.
SEED = 42
LEARNING_RATE = 0.01

pl.seed_everything(SEED)
model = Model(lr=LEARNING_RATE)

Quaterion.fit(
    trainable_model=model,
    trainer=None,  # Use Quaterion's default trainer
    train_dataloader=train_dataloader,
    val_dataloader=val_dataloader,
)