## Ранжирование треков

### - загрузка и распаковка датасетов

In [None]:
# Mount Google Drive: the zipped dataset, split files and output folders used
# below all live under /content/drive/MyDrive/yandex_ml_olimp.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
!pip install jsonlines
!pip install tqdm

Collecting jsonlines
  Downloading jsonlines-4.0.0-py3-none-any.whl.metadata (1.6 kB)
Downloading jsonlines-4.0.0-py3-none-any.whl (8.7 kB)
Installing collected packages: jsonlines
Successfully installed jsonlines-4.0.0


In [None]:
import uuid

# Every path used by later cells is defined here: commenting any of them out
# breaks "Restart & Run All" (temp_path / test_path / train_path are read below).

# Local scratch space on the Colab VM; the dataset is unpacked into <temp_path>/data.
temp_path = "/content/sample_data"
# Zipped dataset on Google Drive.
test_path = "/content/drive/MyDrive/yandex_ml_olimp/data/test.zip"
train_path = "/content/drive/MyDrive/yandex_ml_olimp/data/train.zip"
# Persistent locations on Google Drive.
drive_model_path = "/content/drive/MyDrive/yandex_ml_olimp/models"
test_out_path = "/content/drive/MyDrive/yandex_ml_olimp/outputs_test"
val_out_path = "/content/drive/MyDrive/yandex_ml_olimp/outputs_val"
splits_path = "/content/drive/MyDrive/yandex_ml_olimp/splits"
tsv_path = "/content/drive/MyDrive/yandex_ml_olimp/cliques2versions.tsv"
np_data_path = "/content/drive/MyDrive/yandex_ml_olimp/splits/train_cliques.npy"
# Unique per kernel run, so submissions copied to Drive do not overwrite each other.
submission_name = f"submission_{uuid.uuid4()}.txt"

In [None]:
import shutil
import os
from tqdm import tqdm
from zipfile import ZipFile

# Build a local working copy of the metadata under <temp_path>/data:
# the split index files go to data/splits, the clique table to data/.
temp_dir = os.path.join(temp_path, "data")
temp_splits_dir = os.path.join(temp_dir, "splits")
for _dir in (temp_dir, temp_splits_dir):
    os.makedirs(_dir, exist_ok=True)

names = ["test_ids.npy", "train_cliques.npy", "val_cliques.npy"]
for name in names:
    src, dest = os.path.join(splits_path, name), os.path.join(temp_splits_dir, name)
    shutil.copy(src, dest)
shutil.copy(tsv_path, temp_dir)

# Output folder relative to the current working directory.
os.makedirs("outputs_val", exist_ok=True)

In [None]:
# Unpack the test archive into temp_dir. Best effort: a member that fails to
# extract is reported and skipped, so one bad entry does not abort the cell.
with ZipFile(test_path, 'r') as archive:
    members = archive.infolist()
    for member in tqdm(members, desc="Extracting..."):
        try:
            archive.extract(member, temp_dir)
        except Exception as e:
            print("Extracting error:", e)

Extracting...: 100%|██████████| 56281/56281 [00:27<00:00, 2010.04it/s]


In [None]:
train_path = "/content/drive/MyDrive/yandex_ml_olimp/data/train.zip"

# Unpack the train archive into temp_dir (same best-effort policy as the test
# archive: failing members are reported and skipped).
with ZipFile(train_path, 'r') as archive:
    members = archive.infolist()
    for member in tqdm(members, desc="Extracting..."):
        try:
            archive.extract(member, temp_dir)
        except Exception as e:
            print("Extracting error:", e)

Extracting...: 100%|██████████| 316050/316050 [02:42<00:00, 1939.13it/s]


### - Основные импорты

In [None]:
import glob
import json
import os
import re
from copy import deepcopy
from time import time
from typing import Dict, List, Literal, Tuple

import jsonlines
import numpy as np
import pandas as pd
from tqdm import tqdm, trange
from sklearn.metrics import pairwise_distances, pairwise_distances_chunked

In [None]:
import torch
import torch.nn.functional as F
from torch import nn
from torch.utils.data import DataLoader, Dataset

### - Дополнительные структуры

In [None]:
from dataclasses import dataclass
from yaml import FullLoader, load, safe_load

@dataclass
class bcolors:
    """ANSI terminal escape sequences for coloured console output.

    OKGREEN (code 92), WARNING (93) and FAIL (91) switch the foreground colour;
    ENDC (0) resets all attributes and should follow every coloured fragment.
    """

    OKGREEN: str = "\033[92m"
    WARNING: str = "\033[93m"
    FAIL: str = "\033[91m"
    ENDC: str = "\033[0m"


def load_config(config_path: str) -> Dict:
    """Load a YAML config file and normalise its ``device`` entry.

    :param config_path: path to a YAML file whose top level is a mapping
        containing a ``device`` key.
    :return: the parsed mapping; ``device: gpu`` is rewritten to ``"cuda:0"``,
        any other value is left unchanged.
    :raises ValueError: if the file does not hold a YAML mapping (e.g. it is empty).
    :raises KeyError: if the mapping has no ``device`` key.
    """
    # Explicit UTF-8: the platform default encoding is not guaranteed to be UTF-8,
    # and the config may contain non-ASCII text (e.g. Russian comments).
    with open(config_path, encoding="utf-8") as file:
        config = safe_load(file)

    # safe_load returns None for an empty file and scalars/lists for other
    # documents; fail with a clear message instead of a TypeError below.
    if not isinstance(config, dict):
        raise ValueError(
            f"Config file {config_path!r} must contain a YAML mapping, "
            f"got {type(config).__name__}"
        )

    if config["device"] == "gpu":
        config["device"] = "cuda:0"

    return config

In [None]:
from typing import TypedDict

# Validation record for one track: its id and the two embeddings a model
# returns under the keys "f_t" and "f_c".
ValDict = TypedDict(
    "ValDict",
    {"anchor_id": int, "f_t": torch.Tensor, "f_c": torch.Tensor},
)


# One dataset item: anchor spectrogram with its label, plus a positive and a
# negative example (used for the triplet loss).
# NOTE(review): outside the train split CoverDataset fills positive_id /
# negative_id with empty tensors, so the int annotations are not exact there.
BatchDict = TypedDict(
    "BatchDict",
    {
        "anchor_id": int,
        "anchor": torch.Tensor,
        "anchor_label": torch.Tensor,
        "positive_id": int,
        "positive": torch.Tensor,
        "negative_id": int,
        "negative": torch.Tensor,
    },
)


# Training/validation statistics collected per epoch (presumably shown as the
# progress-bar postfix — confirm in the training loop): epoch number, running
# and per-step losses, validation loss and the ranking metrics mr1 / mAP.
Postfix = TypedDict(
    "Postfix",
    {
        "Epoch": int,
        "train_loss": float,
        "train_loss_step": float,
        "train_cls_loss": float,
        "train_cls_loss_step": float,
        "train_triplet_loss": float,
        "train_triplet_loss_step": float,
        "val_loss": float,
        "mr1": float,
        "mAP": float,
    },
)


# Final metrics on the held-out split: test_mr1 (rank-based score — exact
# definition lives in the evaluation code) and test_mAP.
TestResults = TypedDict("TestResults", {"test_mr1": float, "test_mAP": float})

### - Вспомогательные функции 

In [None]:
def get_cls_weights(df: pd.DataFrame) -> np.ndarray:
    """
    Compute per-class weights for ``nn.CrossEntropyLoss`` from clique sizes.

    Uses the "balanced" inverse-frequency scheme
    ``w_k = n_tracks / (n_classes * n_tracks_in_clique_k)``: small cliques are
    up-weighted, and a clique of average size gets weight 1. Every class keeps a
    strictly positive weight, so no class is ignored by the loss.

    :param df: frame with a ``clique`` column (class index; ``read_df`` maps the
        split's cliques to 0..K-1) and a ``versions`` column holding the list of
        track ids that belong to the clique.
    :return: float64 array of length K; element ``k`` is the weight of the k-th
        clique in ascending ``clique`` order (i.e. of class ``k`` when the ids
        are 0..K-1).
    """
    # Tracks per clique; summed in case a clique is split over several rows.
    counts = (
        df.assign(n_tracks=df["versions"].map(len))
        .groupby("clique")["n_tracks"]
        .sum()
        .to_numpy(dtype=np.float64)
    )
    # Guard the denominator against cliques with an empty version list.
    return counts.sum() / (counts.size * np.maximum(counts, 1.0))

def read_df(data_path="", np_data_path=""):
    '''Read the clique -> versions table restricted to one split.

    :param data_path: tab-separated file with columns ``clique`` and ``versions``
        (the latter a Python-literal list of track ids, e.g. ``[1, 2]``).
    :param np_data_path: ``.npy`` array with the clique ids of the split.
    :return: rows whose clique is in the split, with ``clique`` remapped to
        0..K-1 in ascending order of the original ids. ``clique`` stays a regular
        column (not the index) — callers such as ``get_cls_weights`` select it by name.
    '''
    import ast

    cliques_subset = np.load(np_data_path)

    versions = pd.read_csv(
        data_path,
        sep="\t",
        # literal_eval parses the list literals without executing arbitrary code
        # (eval would run anything written in the file).
        converters={"versions": ast.literal_eval},
    )

    # .copy(): the column is reassigned below; writing into a filtered view
    # triggers pandas' chained-assignment warning.
    versions = versions[versions["clique"].isin(set(cliques_subset))].copy()
    mapping = {clique: k for k, clique in enumerate(sorted(cliques_subset))}
    versions["clique"] = versions["clique"].map(mapping)

    return versions

### - Датасет и даталоадер

In [None]:
class CoverDataset(Dataset):
    """Cover-song dataset of spectrograms stored as ``<c>/<b>/<a>/<track_id>.<ext>``.

    For the ``"train"`` split every item is a triplet: the anchor track, a
    positive (another version from the same clique) and a negative (a track from
    a different clique). For ``"val"`` / ``"test"`` only the anchor is loaded;
    the label is -1 and the triplet fields are empty tensors.

    Files read from ``data_path``: ``cliques2versions.tsv`` (tab-separated,
    columns ``clique`` / ``versions``) and ``splits/{split}_cliques.npy`` for
    train/val, or ``splits/{split}_ids.npy`` for test. Spectrograms are read
    from ``dataset_path``.
    """

    def __init__(
        self,
        data_path: str,
        file_ext: str,
        dataset_path: str,
        data_split: Literal["train", "val", "test"],
        debug: bool,
        max_len: int,
    ) -> None:
        super().__init__()
        self.data_path = data_path
        self.file_ext = file_ext
        self.dataset_path = dataset_path
        self.data_split = data_split
        self.debug = debug  # stored; not used inside this class
        self.max_len = max_len  # stored; no cropping/padding is done here
        self._load_data()
        # Shuffled order for drawing negatives; reshuffled once exhausted.
        self.rnd_indices = np.random.permutation(len(self.track_ids))
        self.current_index = 0

    def __len__(self) -> int:
        return len(self.track_ids)

    def __getitem__(self, index: int) -> "BatchDict":
        track_id = self.track_ids[index]
        anchor_cqt = self._load_cqt(track_id)

        if self.data_split == "train":
            clique_id = self.version2clique.loc[track_id, "clique"]
            pos_id, neg_id = self._triplet_sampling(track_id, clique_id)
            positive_cqt = self._load_cqt(pos_id)
            negative_cqt = self._load_cqt(neg_id)
        else:
            # No labels / triplets outside training.
            clique_id = -1
            pos_id = torch.empty(0)
            positive_cqt = torch.empty(0)
            neg_id = torch.empty(0)
            negative_cqt = torch.empty(0)
        return dict(
            anchor_id=track_id,
            anchor=anchor_cqt,
            anchor_label=torch.tensor(clique_id, dtype=torch.float),
            positive_id=pos_id,
            positive=positive_cqt,
            negative_id=neg_id,
            negative=negative_cqt,
        )

    def _make_file_path(self, track_id, file_ext):
        """Relative path sharded by the last three decimal digits: 1234 -> 2/3/4/1234.ext."""
        a = track_id % 10
        b = track_id // 10 % 10
        c = track_id // 100 % 10
        return os.path.join(str(c), str(b), str(a), f"{track_id}.{file_ext}")

    def _next_random_track(self):
        """Return the next track id in the shuffled order, reshuffling when exhausted."""
        if self.current_index >= len(self.rnd_indices):
            self.current_index = 0
            self.rnd_indices = np.random.permutation(len(self.track_ids))
        track = self.track_ids[self.rnd_indices[self.current_index]]
        self.current_index += 1
        return track

    def _triplet_sampling(self, track_id: int, clique_id: int) -> Tuple[int, int]:
        """Pick a positive from the anchor's clique and a negative from any other clique.

        :raises ValueError: if the clique has no version besides ``track_id``.
        """
        versions = self.versions.loc[clique_id, "versions"]
        pos_list = np.setdiff1d(versions, track_id)
        if len(pos_list) == 0:
            raise ValueError(
                f"Clique {clique_id} has no version other than track {track_id}; "
                "cannot sample a positive."
            )
        pos_id = np.random.choice(pos_list, 1)[0]
        # Draw from the shuffled order until the track is outside the clique.
        # NOTE(review): loops forever if every track belongs to this clique.
        neg_id = self._next_random_track()
        while neg_id in versions:
            neg_id = self._next_random_track()
        return (pos_id, neg_id)

    def _load_data(self) -> None:
        """Populate ``track_ids`` (and, for train/val, ``versions`` / ``version2clique``)."""
        import ast

        if self.data_split in ["train", "val"]:
            cliques_subset = np.load(
                os.path.join(
                    self.data_path, "splits", "{}_cliques.npy".format(self.data_split)
                )
            )

            versions = pd.read_csv(
                os.path.join(self.data_path, "cliques2versions.tsv"),
                sep="\t",
                # literal_eval parses the list literals without executing code.
                converters={"versions": ast.literal_eval},
            )
            # .copy(): avoid writing into a filtered view of the raw table.
            versions = versions[versions["clique"].isin(set(cliques_subset))].copy()
            # Remap the split's clique ids to contiguous 0..K-1 class indices.
            mapping = {clique: k for k, clique in enumerate(sorted(cliques_subset))}
            versions["clique"] = versions["clique"].map(lambda x: mapping[x])
            self.versions = versions.set_index("clique")
            # Inverse lookup: one row per track id -> its (remapped) clique.
            self.version2clique = pd.DataFrame(
                [
                    {"version": version, "clique": clique}
                    for clique, row in self.versions.iterrows()
                    for version in row["versions"]
                ]
            ).set_index("version")
            self.track_ids = self.version2clique.index.to_list()

        else:
            self.track_ids = np.load(
                os.path.join(
                    self.data_path, "splits", "{}_ids.npy".format(self.data_split)
                )
            )

    def _load_cqt(self, track_id: int) -> torch.Tensor:
        """Load the stored ``.npy`` spectrogram of ``track_id`` as a tensor."""
        filename = os.path.join(
            self.dataset_path, self._make_file_path(track_id, self.file_ext)
        )
        cqt_spectrogram = np.load(filename)
        return torch.from_numpy(cqt_spectrogram)


def cover_dataloader(
    data_path: str,
    file_ext: str,
    dataset_path: str,
    data_split: Literal["train", "val", "test"],
    debug: bool,
    max_len: int,
    batch_size: int,
    **config: Dict,
) -> DataLoader:
    """Wrap a ``CoverDataset`` for ``data_split`` in a ``DataLoader``.

    ``config`` must provide ``num_workers``, ``shuffle`` and ``drop_last``; any
    other keyword is accepted and ignored. A non-positive ``max_len`` forces
    ``batch_size=1``.
    """
    dataset = CoverDataset(
        data_path, file_ext, dataset_path, data_split, debug, max_len=max_len
    )
    effective_batch_size = 1 if max_len <= 0 else batch_size
    return DataLoader(
        dataset,
        batch_size=effective_batch_size,
        num_workers=config["num_workers"],
        shuffle=config["shuffle"],
        drop_last=config["drop_last"],
    )

### - Еще вспомогательные функции

In [None]:
def reduce_func(D_chunk, start, top_size: int = 100):
    """Reduce a chunk of the distance matrix to each row's nearest neighbours.

    Meant as the ``reduce_func`` of ``sklearn.metrics.pairwise_distances_chunked``,
    which calls it as ``reduce_func(D_chunk, start)``; row ``r`` of ``D_chunk``
    holds the distances from sample ``start + r`` to every sample.

    :param D_chunk: 2-D array of distances for a block of consecutive samples.
    :param start: global index of the first row of ``D_chunk``.
    :param top_size: number of neighbours to keep per query (default 100).
    :return: list of ``(query_index, neighbour_indices)`` pairs; neighbours are
        ordered by increasing distance, exclude the query itself and have at
        most ``top_size`` entries.
    """
    nearest_items = np.argsort(D_chunk, axis=1)[:, : top_size + 1]
    # Drop the query itself. If zero-distance duplicates pushed it out of the
    # first top_size + 1 items, nothing is removed — truncate so the row never
    # exceeds top_size neighbours.
    return [
        (i, items[items != i][:top_size])
        for i, items in enumerate(nearest_items, start)
    ]


def dataloader_factory(config: Dict, data_split: str) -> DataLoader:
    """Build the ``DataLoader`` for ``data_split`` from the notebook config.

    Shared settings (``data_path``, ``file_extension``, ``debug``) come from the
    top level of ``config``; every key of ``config[data_split]`` is forwarded as
    a keyword argument. That section presumably supplies ``dataset_path``,
    ``batch_size``, ``num_workers``, ``shuffle`` and ``drop_last`` — confirm
    against the YAML config. ``max_len`` is fixed to 50 here, so a ``max_len``
    key in the split section would clash (TypeError).
    """
    common_kwargs = dict(
        data_path=config["data_path"],
        file_ext=config["file_extension"],
        data_split=data_split,
        debug=config["debug"],
        max_len=50,
    )
    return cover_dataloader(**common_kwargs, **config[data_split])


def calculate_ranking_metrics(
    embeddings: np.ndarray, cliques: List[int], metric: str = "euclidean"
) -> Tuple[np.ndarray, np.ndarray]:
    """Per-query retrieval metrics over the full pairwise distance matrix.

    Every embedding is a query against all *other* embeddings, ranked by
    increasing distance; items with the query's clique id are relevant.

    :param embeddings: array of shape (n_items, dim).
    :param cliques: clique id of each row of ``embeddings``.
    :param metric: distance passed to ``sklearn.metrics.pairwise_distances``
        (default ``"euclidean"``).
    :return: ``(ranks, average_precisions)``, both of shape (n_items,):
        ``ranks[i]`` is the reciprocal rank 1/r of the first relevant item for
        query ``i`` and ``average_precisions[i]`` its average precision. Queries
        with no other member of their clique get 0.0 for both.
    """
    cliques = np.asarray(cliques)
    n_items = len(cliques)
    if n_items < 2:
        # No candidates to rank against.
        return (np.zeros(n_items), np.zeros(n_items))

    distances = pairwise_distances(embeddings, metric=metric)
    # Exclude the query explicitly: with duplicate embeddings another item can
    # tie at distance 0 and sort first, which would replace the query's clique.
    np.fill_diagonal(distances, np.inf)
    # The query now sorts last; drop that column. Stable sort keeps ties deterministic.
    order = np.argsort(distances, axis=1, kind="stable")[:, :-1]

    mask = cliques[order] == cliques[:, None]  # relevance, (n_items, n_items - 1)
    n_relevant = mask.sum(axis=1)
    has_relevant = n_relevant > 0

    # argmax is 0 for all-False rows; without the guard such queries would
    # score a perfect 1.0.
    ranks = np.where(has_relevant, 1.0 / (mask.argmax(axis=1) + 1.0), 0.0)

    precision_at_k = np.cumsum(mask, axis=1) / np.arange(1, mask.shape[1] + 1)
    ap_sum = (precision_at_k * mask).sum(axis=1)
    average_precisions = np.divide(
        ap_sum, n_relevant, out=np.zeros_like(ap_sum), where=has_relevant
    )

    return (ranks, average_precisions)


def dir_checker(output_dir: str) -> str:
    """Return the path of the next free ``run-N`` directory under ``output_dir``.

    Any ``run-<digits>`` component (with trailing slashes) is first stripped
    from ``output_dir``, so passing an existing run directory yields a sibling.
    ``N`` is one more than the largest numeric ``run-N`` entry present, or 0 if
    there is none. Entries like ``run-old`` that are not ``run-<digits>`` are
    ignored. The directory itself is not created.
    """
    base_dir = re.sub(r"run-[0-9]+/*", "", output_dir)
    run_numbers = []
    for path in glob.glob(os.path.join(base_dir, "run-*")):
        match = re.fullmatch(r"run-([0-9]+)", os.path.basename(path))
        if match:
            run_numbers.append(int(match.group(1)))
    next_run = max(run_numbers) + 1 if run_numbers else 0
    return os.path.join(base_dir, f"run-{next_run}")


def save_test_predictions(
    predictions: List, output_dir: str, file_name=None, copy_dir=None
) -> None:
    """Write test-set nearest-neighbour predictions and copy the file to Drive.

    Each line is ``"<query_id> <id_1> <id_2> ..."`` with all ids cast to ``int``.

    :param predictions: iterable of ``(query_id, nearest_ids)`` pairs.
    :param output_dir: local directory to write into (created if missing).
    :param file_name: output file name; defaults to the notebook-level
        ``submission_name``.
    :param copy_dir: directory the written file is copied into; defaults to the
        notebook-level ``test_out_path``. Created if missing — otherwise
        ``shutil.copy`` would create a *file* with that name.
    """
    if file_name is None:
        file_name = submission_name
    if copy_dir is None:
        copy_dir = test_out_path

    os.makedirs(output_dir, exist_ok=True)
    sub_path = os.path.join(output_dir, file_name)
    with open(sub_path, "w") as foutput:
        for query_item, query_nearest in predictions:
            neighbours = " ".join(str(int(item)) for item in query_nearest)
            foutput.write(f"{int(query_item)} {neighbours}\n")

    # Copy the file that was just written, wherever output_dir points.
    os.makedirs(copy_dir, exist_ok=True)
    shutil.copy(sub_path, copy_dir)


def save_predictions(outputs: Dict[str, np.ndarray], output_dir: str) -> None:
    """Dump evaluation outputs into ``output_dir`` (created if missing).

    Keys containing ``"_ids"`` are written as JSON Lines to ``<key>.jsonl``:
    rows of 4 values become ``clique_id / anchor_id / positive_id / negative_id``
    records, otherwise rows are unpacked as ``(clique_id, anchor_id)``; the row
    width is taken from the first row. An empty ``*_ids`` entry produces an
    empty ``.jsonl`` file. Every other key is saved with ``np.save`` as ``<key>.npy``.

    NOTE(review): jsonlines serialises through ``json``; numpy/torch scalars in
    the id rows would not be JSON-serialisable — confirm callers pass plain ints.
    """
    os.makedirs(output_dir, exist_ok=True)
    for key, value in outputs.items():
        if "_ids" not in key:
            np.save(os.path.join(output_dir, f"{key}.npy"), value)
            continue

        path = os.path.join(output_dir, f"{key}.jsonl")
        if len(value) == 0:
            # No first row to infer the layout from; write an empty JSONL file
            # instead of failing with IndexError on value[0].
            with open(path, "w", encoding="utf-8"):
                pass
            continue

        with jsonlines.open(path, "w") as f:
            if len(value[0]) == 4:
                for clique, anchor, pos, neg in value:
                    f.write(
                        {
                            "clique_id": clique,
                            "anchor_id": anchor,
                            "positive_id": pos,
                            "negative_id": neg,
                        }
                    )
            else:
                for clique, anchor in value:
                    f.write({"clique_id": clique, "anchor_id": anchor})

### - Дополнительные структуры для ResNet

In [None]:
class GeM(nn.Module):
    """Generalized-mean (GeM) pooling over the last two dimensions.

    Computes ``mean(clamp(x, min=eps) ** p) ** (1 / p)`` over the full H x W
    extent, with ``p`` a learnable 1-element parameter initialised to the ``p``
    argument (``p = 1`` is plain average pooling). H and W are reduced to 1.
    """

    def __init__(self, p=3, eps=1e-6):
        super().__init__()
        self.p = nn.Parameter(p * torch.ones(1))
        self.eps = eps

    def forward(self, x):
        return self.gem(x, p=self.p, eps=self.eps)

    def gem(self, x, p=3, eps=1e-6):
        # Clamping keeps the fractional power 1/p well-defined.
        powered = x.clamp(min=eps).pow(p)
        window = (x.size(-2), x.size(-1))  # pool over the whole spatial extent
        return F.avg_pool2d(powered, window).pow(1.0 / p)

    def __repr__(self):
        p_value = self.p.data.tolist()[0]
        return f"{type(self).__name__}(p={p_value:.4f}, eps={str(self.eps)})"


class IBN(nn.Module):
    r"""Instance-Batch Normalization layer from
    `"Two at Once: Enhancing Learning and Generalization Capacities via IBN-Net"
    <https://arxiv.org/pdf/1807.09441.pdf>`

    The first ``int(planes * ratio)`` channels go through InstanceNorm (affine),
    the remaining channels through BatchNorm; the results are concatenated back
    in the original channel order.

    Args:
        planes (int): Number of channels for the input tensor
        ratio (float): Ratio of instance normalization in the IBN layer
    """

    def __init__(self, planes, ratio):
        super(IBN, self).__init__()
        self.half = int(planes * ratio)
        self.IN = nn.InstanceNorm2d(self.half, affine=True)
        self.BN = nn.BatchNorm2d(planes - self.half)

    def forward(self, x):
        # Split into exactly two groups [half | rest]. Splitting by a fixed
        # chunk size of `half` yields more than two pieces — and a channel
        # mismatch for BN — whenever ratio != 0.5 or planes is odd.
        in_part, bn_part = torch.split(x, [self.half, x.size(1) - self.half], dim=1)
        out1 = self.IN(in_part.contiguous())
        out2 = self.BN(bn_part.contiguous())
        return torch.cat((out1, out2), 1)

### - Класс Bottleneck

In [None]:
class Bottleneck(nn.Module):
    """ResNet bottleneck block (1x1 -> 3x3 -> 1x1 convs) with IBN normalisation.

    After the first conv, IBN (instance norm on half the channels, batch norm on
    the rest) is used unless ``last=True``, in which case plain BatchNorm is used.
    The output has ``out_channels * expansion`` channels and is spatially
    downsampled by ``stride`` through the 3x3 conv.

    Args:
        in_channels: channels of the input tensor.
        out_channels: bottleneck width (before expansion).
        last: use BatchNorm instead of IBN after the first conv.
        downsample: optional module applied to the shortcut so it matches the
            main branch's shape (needed when stride != 1 or channels change).
        stride: stride of the 3x3 conv.
        bias: whether the three convs have a bias term.
    """

    expansion: int = 4

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        last: bool = False,
        downsample=None,
        stride=1,
        bias: bool = True,
    ):
        super(Bottleneck, self).__init__()

        self.conv1 = nn.Conv2d(
            in_channels, out_channels, kernel_size=1, stride=1, padding=0, bias=bias
        )
        if not last:
            # Apply Instance normalization in first half channels (ratio=0.5)
            self.ibn = IBN(out_channels, ratio=0.5)
        else:
            self.ibn = nn.BatchNorm2d(out_channels)

        self.conv2 = nn.Conv2d(
            out_channels,
            out_channels,
            kernel_size=3,
            stride=stride,
            padding=1,
            bias=bias,
        )
        self.batch_norm2 = nn.BatchNorm2d(out_channels)

        self.conv3 = nn.Conv2d(
            out_channels,
            out_channels * self.expansion,
            kernel_size=1,
            stride=1,
            padding=0,
            bias=bias,
        )
        self.batch_norm3 = nn.BatchNorm2d(out_channels * self.expansion)

        self.downsample = downsample
        self.stride = stride
        self.relu = nn.ReLU()

    def forward(self, x: torch.Tensor):
        # The shortcut can reference `x` directly: nothing below modifies it in
        # place (self.relu is a non-inplace nn.ReLU), so copying it is unnecessary.
        residual = x

        x = self.conv1(x)
        x = self.ibn(x)
        x = self.relu(x)

        x = self.conv2(x)
        x = self.batch_norm2(x)
        x = self.relu(x)

        x = self.conv3(x)
        x = self.batch_norm3(x)
        # NOTE(review): the standard ResNet bottleneck does not apply ReLU to the
        # branch before the residual sum; kept so trained checkpoints behave the same.
        x = self.relu(x)

        if self.downsample is not None:
            residual = self.downsample(residual)

        out = residual + x
        out = self.relu(out)

        return out

### - Resnet50
*Эту архитектуру создатели предложили в качестве baseline*

In [None]:
class Resnet50(nn.Module):
    """ResNet-50-style backbone (IBN bottlenecks + GeM pooling) for spectrograms.

    Layout: 7x7 stride-2 stem conv -> 3x3 stride-2 max-pool -> four stages of
    ``ResBlock`` with 3/4/6/3 blocks (the last stage uses stride 1 and
    ``last=True``) -> GeM pooling -> dropout. ``forward`` returns
    ``dict(f_t=pooled embedding, f_c=batch-normalised embedding, cls=logits)``.

    Args:
        ResBlock: residual block class; must expose an ``expansion`` attribute.
        emb_dim: size of the pooled embedding. NOTE(review): must equal
            ``512 * ResBlock.expansion`` (2048 for Bottleneck) because that is the
            channel count of the last stage; it is not derived automatically.
        num_channels: input channels of the stem conv; ``forward`` inserts a
            single channel dim, so this is expected to be 1.
        num_classes: number of output logits.
        dropout: dropout probability applied to the pooled embedding.
        n_bins: accepted but unused in this class.
    """

    def __init__(
        self,
        ResBlock: Bottleneck,
        emb_dim: int = 2048,
        num_channels: int = 1,
        num_classes: int = 8858,
        dropout=0.1,
        n_bins=84,
    ) -> None:

        super(Resnet50, self).__init__()
        # Running channel count; _make_layer updates it after each stage.
        self.in_channels = 64

        self.conv1 = nn.Conv2d(
            in_channels=num_channels,
            out_channels=64,
            kernel_size=7,
            stride=2,
            padding=3,
            bias=False,
        )
        self.batch_norm1 = nn.BatchNorm2d(num_features=64)
        self.relu = nn.ReLU()
        self.max_pool1 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        # NOTE: module creation order fixes the random initialisation sequence.
        self.layer1 = self._make_layer(ResBlock, blocks=3, planes=64, stride=1)
        self.layer2 = self._make_layer(ResBlock, blocks=4, planes=128, stride=2)
        self.layer3 = self._make_layer(ResBlock, blocks=6, planes=256, stride=2)
        self.layer4 = self._make_layer(
            ResBlock, blocks=3, planes=512, stride=1, last=True
        )

        self.gem_pool = GeM()
        self.dropout = nn.Dropout(p=dropout)

        self.bn_fc = nn.BatchNorm1d(emb_dim)
        self.fc = nn.Linear(emb_dim, num_classes, bias=False)
        nn.init.kaiming_normal_(self.fc.weight)

    def _make_layer(
        self,
        ResBlock: Bottleneck,
        blocks: int,
        planes: int,
        stride: int = 1,
        last: bool = False,
    ):
        """Build one stage of ``blocks`` residual blocks of width ``planes``.

        Only the first block receives ``stride`` and the shortcut projection;
        afterwards ``self.in_channels`` is set to ``planes * expansion`` so the
        next stage starts from this stage's output width.
        """
        downsample = None
        # Project the shortcut whenever the spatial size or channel count changes.
        if stride != 1 or self.in_channels != planes * ResBlock.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(
                    self.in_channels,
                    planes * ResBlock.expansion,
                    kernel_size=1,
                    stride=stride,
                    bias=False,
                ),
                nn.BatchNorm2d(planes * ResBlock.expansion),
            )
        layers = []
        layers.append(
            ResBlock(
                in_channels=self.in_channels,
                out_channels=planes,
                stride=stride,
                downsample=downsample,
                last=last,
            )
        )
        self.in_channels = planes * ResBlock.expansion
        for _ in range(1, blocks):
            layers.append(
                ResBlock(in_channels=self.in_channels, out_channels=planes, last=last)
            )

        return nn.Sequential(*layers)

    def forward(self, x: torch.Tensor):
        """Return ``dict(f_t, f_c, cls)`` for a batch of spectrograms.

        NOTE(review): ``x`` is presumably ``(batch, n_bins, time)`` — the channel
        dim is inserted here; confirm against the dataset.
        """
        # Unsqueeze to simulate 1-channel image
        x = self.conv1(x.unsqueeze(1))
        x = self.batch_norm1(x)
        x = self.relu(x)
        x = self.max_pool1(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        # GeM output is (batch, channels, 1, 1); flatten to (batch, channels).
        f_t = self.gem_pool(x)
        f_t = self.dropout(torch.flatten(f_t, start_dim=1))

        f_c = self.bn_fc(f_t)
        cls = self.fc(f_c)

        return dict(f_t=f_t, f_c=f_c, cls=cls)

### - Трансформер ViT
Попытка заменить модель на Vision Transformer (ViT). Результат — скор 0.2.

In [None]:
from torchvision.models import vit_b_16
from functools import partial
from torchvision.models.vision_transformer import Encoder
from typing import Callable

class MyViT(nn.Module):
    """Small Vision Transformer on spectrograms resized to 64x64.

    Reuses a torchvision ``VisionTransformer`` instance (``pretrained_model``) as
    a shell and replaces its patch projection (1 input channel, 8x8 patches),
    its encoder (4 layers, 12 heads) and its head (``hidden_dim -> emb_dim``).
    ``forward`` returns ``dict(f_t, f_c, cls)``, the same keys as the other
    models in this notebook.

    NOTE(review): ``pretrained_model`` is presumably created with
    ``image_size=64`` (as in the commented-out setup code), since torchvision's
    ViT checks the input size against its own ``image_size``; its class token
    also keeps its original width, so ``hidden_dim`` must match it (768 for
    ViT-B). The default ``num_classes`` (39553) differs from the 39535 used
    elsewhere — pass it explicitly.
    """

    def __init__(
        self,
        pretrained_model,
        hidden_dim=768,
        emb_dim=2048,
        num_classes=39553):
        super(MyViT, self).__init__()
        self.image_size = 64

        patch_size = 8
        dropout = 0.05
        attention_dropout = 0.05
        norm_layer: Callable[..., torch.nn.Module] = partial(nn.LayerNorm, eps=1e-6)

        # backbone shell
        self.pretrained = pretrained_model

        # 1-channel patch embedding instead of the RGB one
        self.pretrained.conv_proj = nn.Conv2d(
            1, hidden_dim, kernel_size=(patch_size, patch_size), stride=(patch_size, patch_size)
        )
        self.pretrained.patch_size = patch_size
        # number of patches + 1 class token
        seq_length = (self.image_size // patch_size) ** 2 + 1

        self.pretrained.encoder = Encoder(
            seq_length=seq_length,
            num_layers=4,
            num_heads=12,
            hidden_dim=hidden_dim,
            mlp_dim=3072,
            dropout=dropout,
            attention_dropout=attention_dropout,
            norm_layer=norm_layer)

        self.pretrained.encoder.pos_embedding = nn.Parameter(
            torch.empty(1, seq_length, hidden_dim).normal_(std=0.02)
        )

        # the ViT head now produces the embedding f_t
        self.pretrained.heads.head = nn.Linear(
            in_features=hidden_dim, out_features=emb_dim, bias=True
        )

        self.bn_fc = nn.BatchNorm1d(emb_dim)
        self.fc = nn.Linear(emb_dim, num_classes, bias=False)
        nn.init.kaiming_normal_(self.fc.weight)

    def forward(self, x):
        # add a channel dim and resize to the square input the ViT expects
        x = x.unsqueeze(1)
        x = nn.functional.interpolate(x, size=(self.image_size, self.image_size)).to(
            dtype=torch.float)
        f_t = self.pretrained(x)
        f_c = self.bn_fc(f_t)
        cls = self.fc(f_c)
        return dict(f_t=f_t, f_c=f_c, cls=cls)

In [None]:
from torchvision.models import vit_b_32
from functools import partial
from torchvision.models.vision_transformer import Encoder
from typing import Callable


class MyTransformer(nn.Module):
    """Wrap a 3-channel image model (``pretrained_model``) as an embedding model.

    Spectrograms are replicated to 3 channels, resized to 224x224 and passed
    through ``pretrained_model``, whose output must be 1000-dimensional (the
    input width of ``new_head``). They are then projected to ``emb_dim``,
    batch-normalised and classified. ``hidden_dim`` is accepted but unused.

    NOTE(review): torchvision's VisionTransformer keeps its classifier in
    ``.heads``, so assigning ``.head`` below does not strip it — consistent with
    the 1000-wide ``new_head`` input, but confirm this is intended.
    """

    def __init__(self,
        pretrained_model,
        hidden_dim=768,
        emb_dim=2048,
        num_classes=39553):
        super(MyTransformer, self).__init__()
        self.image_size = 224
        self.pretrained = pretrained_model
        self.pretrained.head = nn.Identity()
        self.new_head = nn.Sequential(nn.Linear(1000, emb_dim))
        self.bn_fc = nn.BatchNorm1d(emb_dim)
        self.fc = nn.Linear(emb_dim, num_classes, bias=False)
        nn.init.kaiming_normal_(self.fc.weight)

    def forward(self, x):
        single = x.unsqueeze(1)
        rgb_like = torch.cat((single,) * 3, dim=1)
        target_size = (self.image_size, self.image_size)
        resized = nn.functional.interpolate(rgb_like, size=target_size).to(dtype=torch.float)
        f_t = self.new_head(self.pretrained(resized))
        f_c = self.bn_fc(f_t)
        return dict(f_t=f_t, f_c=f_c, cls=self.fc(f_c))

### - ResNet18
Попытка заменить модель на облегчённый ResNet (ResNet-18).

In [None]:
from torchvision.models import resnet18
import torch
import torch.nn as nn
class ResNet18(nn.Module):
    """Randomly initialised torchvision ResNet-18 adapted to spectrogram input.

    The input gets a channel dim, is resized to ``image_size x image_size`` and
    replicated to 3 channels. The backbone's final fc layer is replaced by a
    projection to ``emb_dim``. ``forward`` returns
    ``dict(f_t=embedding, f_c=batch-normalised embedding, cls=logits)``.

    Args:
        emb_dim: embedding size (default 512).
        num_classes: number of output logits (default 39535); pass the config
            value so it always matches the classification targets.
        image_size: side of the square the spectrogram is resized to (default 64).
    """

    def __init__(self, emb_dim: int = 512, num_classes: int = 39535, image_size: int = 64):
        super().__init__()
        self.image_size = image_size
        self.resnet = resnet18(weights=None)
        # Reuse the backbone's feature width (512 for ResNet-18) instead of hard-coding it.
        self.resnet.fc = nn.Linear(
            in_features=self.resnet.fc.in_features, out_features=emb_dim, bias=True
        )
        self.bn_fc = nn.BatchNorm1d(emb_dim)
        self.fc = nn.Linear(emb_dim, num_classes, bias=False)

        nn.init.kaiming_normal_(self.fc.weight)

    def forward(self, x):
        # NOTE(review): x is presumably (batch, n_bins, time) — confirm.
        x = x.unsqueeze(1)  # add channel dim
        x = nn.functional.interpolate(x, size=(self.image_size, self.image_size))
        x = torch.cat([x, x, x], dim=1)  # the ResNet stem expects 3 channels
        f_t = self.resnet(x)
        f_c = self.bn_fc(f_t)
        cls = self.fc(f_c)
        return dict(f_t=f_t, f_c=f_c, cls=cls)

### - EarlyStopper

In [None]:
class EarlyStopper:
    """Signal early stopping when validation mAP stops improving.

    Call once per validation round with the new mAP; the call returns ``True``
    when training should stop. A new best resets the counter. A value at least
    ``delta`` below the best so far increments the counter. A value less than
    ``delta`` below the best leaves the counter unchanged. Stopping is signalled
    once the counter reaches ``patience``.
    """

    def __init__(self, patience: int = 1, delta: int = 0):
        self.patience = patience
        self.delta = delta
        self.counter = 0
        self.max_validation_mAP = -np.inf

    def __call__(self, validation_mAP) -> bool:
        best_so_far = self.max_validation_mAP
        if validation_mAP > best_so_far:
            self.max_validation_mAP = validation_mAP
            self.counter = 0
            return False
        if not validation_mAP <= best_so_far - self.delta:
            # Drop smaller than delta: tolerated, counter untouched.
            return False
        self.counter += 1
        return self.counter >= self.patience

### - Основной пайплайн

In [None]:
class TrainModule:
    """Train / validate / test driver for the track-ranking embedding model.

    Holds the model, the losses (cosine-distance triplet loss plus weighted
    cross-entropy), the optimizer, the AMP grad scaler and the early stopper.
    Entry points: ``pipeline`` (training), ``validate`` and ``test``.
    """

    def __init__(self, config: Dict) -> None:
        self.config = config
        self.state = "initializing"
        self.best_model_path: str = None
        # Number of classes of the classification head (39535 for this dataset).
        self.num_classes = self.config["train"]["num_classes"]
        self.max_len = 50

        # Earlier experiments (Resnet50 + Bottleneck, vit_b_32 wrapped in MyTransformer,
        # vit_b_16 wrapped in MyViT) were replaced by the lighter ResNet18.
        self.model = ResNet18().to(device=config["device"])

        self.postfix: Postfix = {}

        # Triplet loss on cosine distance (1 - cosine similarity) instead of L2.
        self.triplet_loss = nn.TripletMarginWithDistanceLoss(
            distance_function=lambda x, y: 1.0 - F.cosine_similarity(x, y),
            margin=config["train"]["triplet_margin"],
        )
        # Class weights from get_cls_weights over the clique table (tsv_path)
        # restricted to the train cliques (np_data_path).
        np_weights = get_cls_weights(read_df(tsv_path, np_data_path))
        weight = torch.from_numpy(np_weights).to(
            dtype=torch.float32, device=self.config["device"]
        )
        self.cls_loss = nn.CrossEntropyLoss(
            weight=weight,
            label_smoothing=config["train"]["smooth_factor"],
        )

        self.early_stop = EarlyStopper(patience=self.config["train"]["patience"])
        self.optimizer = self.configure_optimizers()

        if self.config["device"] != "cpu":
            # torch.amp.GradScaler replaces the deprecated torch.cuda.amp.GradScaler.
            self.scaler = torch.amp.GradScaler(
                self.config["device"].split(":")[0],
                enabled=self.config["train"]["mixed_precision"],
            )

    def _load_weights(self, ckpt_path: str) -> None:
        """Load a state_dict checkpoint into the model (non-strict) on the configured device."""
        # weights_only=True: a state_dict only holds tensors, so refuse arbitrary pickled objects.
        state_dict = torch.load(
            ckpt_path,
            map_location=torch.device(self.config["device"]),
            weights_only=True,
        )
        self.model.load_state_dict(state_dict, strict=False)

    def _short_postfix(self) -> Dict[str, float]:
        """Subset of ``self.postfix`` shown on the epoch progress bar."""
        return {
            k: self.postfix[k]
            for k in self.postfix.keys() & {"train_loss_step", "mr1", "mAP"}
        }

    def pipeline(self) -> None:
        """Train for ``config["train"]["epochs"]`` epochs.

        Optionally resumes from ``config["train"]["model_ckpt"]``. A KeyboardInterrupt
        during an epoch stops training gracefully (state becomes "interrupted").
        """
        self.config["val"]["output_dir"] = dir_checker(self.config["val"]["output_dir"])

        if self.config["train"]["model_ckpt"] is not None:
            self._load_weights(self.config["train"]["model_ckpt"])
            print(f'Model loaded from checkpoint: {self.config["train"]["model_ckpt"]}')

        self.t_loader = dataloader_factory(config=self.config, data_split="train")
        self.v_loader = dataloader_factory(config=self.config, data_split="val")
        self.state = "running"

        self.pbar = trange(
            self.config["train"]["epochs"],
            disable=(not self.config["progress_bar"]),
            position=0,
            leave=True,
        )
        for epoch in self.pbar:
            if self.state in ["early_stopped", "interrupted", "finished"]:
                return

            self.postfix["Epoch"] = epoch
            self.pbar.set_postfix(self.postfix)

            try:
                self.train_procedure()
            except KeyboardInterrupt:
                print("\nKeyboard Interrupt detected. Attempting gracefull shutdown...")
                self.state = "interrupted"

        self.state = "finished"

    def validate(self) -> None:
        """Run one validation pass with the current model weights."""
        self.v_loader = dataloader_factory(config=self.config, data_split="val")
        self.state = "running"
        self.validation_procedure()
        self.state = "finished"

    def test(self) -> None:
        """Compute test-split nearest neighbours and save them as predictions.

        Weights come from the last checkpoint saved in this session, else from
        ``config["test"]["model_ckpt"]``; otherwise the current weights are used.
        """
        self.test_loader = dataloader_factory(config=self.config, data_split="test")
        self.test_results: TestResults = {}

        if self.best_model_path is not None:
            self._load_weights(self.best_model_path)
            print(f"Best models loaded from checkpoint: {self.best_model_path}")
        elif self.config["test"]["model_ckpt"] is not None:
            self._load_weights(self.config["test"]["model_ckpt"])
            print(f'Model loaded from checkpoint: {self.config["test"]["model_ckpt"]}')
        elif self.state == "initializing":
            print("Warning: Testing with random weights")

        self.model = self.model.to(dtype=torch.float)

        self.state = "running"
        self.test_procedure()
        self.state = "finished"

    def train_procedure(self) -> None:
        """Run one training epoch, store mean losses in ``self.postfix`` and save a checkpoint."""
        self.model.train()
        history = {"train_loss": [], "train_cls_loss": [], "train_triplet_loss": []}
        self.max_len = self.t_loader.dataset.max_len
        for batch in tqdm(
            self.t_loader,
            total=len(self.t_loader),
            disable=(not self.config["progress_bar"]),
            position=2,
            leave=False,
        ):
            train_step = self.training_step(batch)
            # (postfix key, training_step key, epoch-history key)
            for postfix_key, step_key, history_key in (
                ("train_loss_step", "train_loss_step", "train_loss"),
                ("train_cls_loss_step", "train_cls_loss", "train_cls_loss"),
                ("train_triplet_loss_step", "train_triplet_loss", "train_triplet_loss"),
            ):
                self.postfix[postfix_key] = float(f"{train_step[step_key]:.3f}")
                history[history_key].append(train_step[step_key])
            self.pbar.set_postfix(self._short_postfix())

        for history_key, values in history.items():
            self.postfix[history_key] = torch.tensor(values).mean().item()

        # A new uniquely named checkpoint is written after every epoch (older ones are
        # not removed), so best_model_path points to the most recent epoch.
        self.best_model_path = os.path.join(
            drive_model_path, f"best-resnet18_{str(uuid.uuid4())}.pt"
        )
        torch.save(self.model.state_dict(), self.best_model_path)

        # Validation and early stopping are currently disabled; re-enable with
        # self.validation_procedure() followed by self.overfit_check().
        self.pbar.set_postfix(self._short_postfix())

    def training_step(self, batch: BatchDict) -> Dict[str, float]:
        """One optimisation step on a triplet batch; returns the scalar losses."""
        device = self.config["device"]
        with torch.autocast(
            device_type=device.split(":")[0],
            enabled=self.config["train"]["mixed_precision"],
        ):
            anchor = self.model.forward(batch["anchor"].to(device))
            positive = self.model.forward(batch["positive"].to(device))
            negative = self.model.forward(batch["negative"].to(device))
            l1 = self.triplet_loss(anchor["f_t"], positive["f_t"], negative["f_t"])
            # Build the one-hot target on the device instead of copying a dense
            # (batch, num_classes) matrix from host memory.
            labels = nn.functional.one_hot(
                batch["anchor_label"].long().to(device), num_classes=self.num_classes
            ).float()
            # The classification term is down-weighted relative to the triplet term.
            l2 = 0.1 * self.cls_loss(anchor["cls"], labels)
            loss = l1 + l2

        self.optimizer.zero_grad()
        if device != "cpu":
            self.scaler.scale(loss).backward()
            self.scaler.step(self.optimizer)
            self.scaler.update()
        else:
            loss.backward()
            self.optimizer.step()

        return {
            "train_loss_step": loss.item(),
            "train_triplet_loss": l1.item(),
            "train_cls_loss": l2.item(),
        }

    def validation_procedure(self) -> None:
        """Embed the validation split, compute ranking metrics and optionally save outputs."""
        print("start val proc")
        self.model.eval()
        embeddings: Dict[int, torch.Tensor] = {}
        for batch in tqdm(
            self.v_loader,
            disable=(not self.config["progress_bar"]),
            position=1,
            leave=False,
        ):
            val_dict = self.validation_step(batch)
            # validation_step squeezes dim 0, so a batch of one comes back 1-D.
            if val_dict["f_t"].ndim == 1:
                val_dict["f_c"] = val_dict["f_c"].unsqueeze(0)
                val_dict["f_t"] = val_dict["f_t"].unsqueeze(0)
            for anchor_id, triplet_embedding, embedding in zip(
                val_dict["anchor_id"], val_dict["f_t"], val_dict["f_c"]
            ):
                embeddings[anchor_id] = torch.stack([triplet_embedding, embedding])

        val_outputs = self.validation_epoch_end(embeddings)
        print(
            f"\n{' Validation Results ':=^50}\n"
            + "\n".join([f'"{key}": {value}' for key, value in self.postfix.items()])
            + f"\n{' End of Validation ':=^50}\n"
        )

        if self.config["val"]["save_val_outputs"]:
            val_outputs["val_embeddings"] = torch.stack(list(embeddings.values()))[
                :, 1
            ].numpy()
            save_predictions(val_outputs, output_dir=self.config["val"]["output_dir"])
        self.model.train()

    def validation_epoch_end(
        self, outputs: Dict[int, torch.Tensor]
    ) -> Dict[str, np.ndarray]:
        """Compute MRR and mAP over the BN-normalised (f_c) validation embeddings.

        Stores ``mrr`` and ``mAP`` in ``self.postfix`` and returns the per-query
        reciprocal ranks and average precisions.
        """
        print("validation_epoch_end")
        clique_ids = [
            self.v_loader.dataset.version2clique.loc[anchor_id, "clique"]
            for anchor_id in outputs
        ]
        preds = torch.stack(list(outputs.values()))[:, 1]
        rranks, average_precisions = calculate_ranking_metrics(
            embeddings=preds.float().numpy(), cliques=clique_ids
        )
        self.postfix["mrr"] = rranks.mean()
        self.postfix["mAP"] = average_precisions.mean()
        return {
            "rranks": rranks,
            "average_precisions": average_precisions,
        }

    def validation_step(self, batch: BatchDict) -> ValDict:
        """Embed the anchors of a batch without tracking gradients."""
        anchor_id = batch["anchor_id"]
        with torch.no_grad():
            features = self.model.forward(
                batch["anchor"].to(dtype=torch.float, device=self.config["device"])
            )

        return {
            # Keep the ids' own (integer) dtype: casting to float32 loses precision
            # for ids above 2**24 and turns them into floats in the predictions.
            "anchor_id": anchor_id.numpy(),
            "f_t": features["f_t"].float().squeeze(0).cpu(),
            "f_c": features["f_c"].float().squeeze(0).cpu(),
        }

    def test_procedure(self) -> None:
        """Embed the test split and save, per track, its nearest tracks by cosine distance."""
        print("start test procedure")
        self.model.eval()
        trackids: List[int] = []
        embeddings: List[np.array] = []
        for batch in tqdm(self.test_loader, disable=(not self.config["progress_bar"])):
            test_dict = self.validation_step(batch)
            if test_dict["f_c"].ndim == 1:
                test_dict["f_c"] = test_dict["f_c"].unsqueeze(0)
            for anchor_id, embedding in zip(test_dict["anchor_id"], test_dict["f_c"]):
                trackids.append(anchor_id)
                embeddings.append(embedding.numpy())
        predictions = []
        for chunk_result in pairwise_distances_chunked(
            embeddings, metric="cosine", reduce_func=reduce_func, working_memory=100
        ):
            for query_indx, query_nearest_items in chunk_result:
                predictions.append(
                    (
                        trackids[query_indx],
                        [trackids[nn_indx] for nn_indx in query_nearest_items],
                    )
                )
        print("saving test....")
        save_test_predictions(predictions, output_dir=self.config["test"]["output_dir"])

    def overfit_check(self) -> None:
        """Update the early stopper with the latest mAP and report whether it improved."""
        print("overfit_check")

        previous_best = self.early_stop.max_validation_mAP
        if self.early_stop(self.postfix["mAP"]):
            print(
                f"\nValidation not improved for {self.early_stop.patience} consecutive epochs. Stopping..."
            )
            self.state = "early_stopped"

        # Compare best scores directly: with delta > 0 a small drop leaves the
        # counter at 0, so "counter == 0" alone does not mean "improved".
        if self.early_stop.max_validation_mAP > previous_best:
            print(
                f"\nMetric improved. New best score: {self.early_stop.max_validation_mAP:.3f}"
            )
        else:
            print("\nValidation mAP was not improved")

    def configure_optimizers(self) -> torch.optim.Optimizer:
        """Adam over all model parameters with ``config["train"]["learning_rate"]``."""
        optimizer = torch.optim.Adam(
            self.model.parameters(), lr=self.config["train"]["learning_rate"]
        )
        return optimizer

### - Тренировка

In [None]:
# Release cached CUDA memory left over from earlier runs in this session.
torch.cuda.empty_cache()

config = load_config(
    config_path="/content/drive/MyDrive/yandex_ml_olimp/config/config_colab_2.yaml"
)
trainer = TrainModule(config)

# Training is currently switched off; only the test-split inference below runs.
# trainer.pipeline()
trainer.test()

  self.scaler = torch.cuda.amp.GradScaler(
  torch.load(self.config["test"]["model_ckpt"], map_location=torch.device(self.config["device"])), strict=False


Model loaded from checkpoint: /content/drive/MyDrive/yandex_ml_olimp/models/best-resnet18_e4344df2-ae4e-42f1-a1ba-9f52f0740816.pt
start test procedure


100%|██████████| 1725/1725 [01:05<00:00, 26.15it/s]


saving test....
