In [None]:
import sys
import timm

from itertools import zip_longest
import json
import math
import gc
import os
from pathlib import Path

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchvision.io import read_image
from torchvision.transforms import Resize, RandomHorizontalFlip, ColorJitter, Normalize, Compose, RandomResizedCrop, \
    CenterCrop, ToTensor

from tqdm import tqdm
from PIL import Image
import joblib
from scipy.sparse import hstack, vstack, csc_matrix, csr_matrix

import networkx as nx
from transformers import BertConfig, BertModel, BertTokenizerFast

# Global configuration shared by the notebook cells.
NUM_CLASSES = 11_014  # class count (not referenced in the visible cells)
NUM_WORKERS = 2       # worker count (not referenced in the visible cells)
SEED = 0              # random seed (not referenced in the visible cells)
device = "cuda"       # torch device on which image batches are placed

def gem(x, p=3, eps=1e-6):
    """Generalized-mean pool `x` over its last two dimensions.

    Values are clamped to at least `eps`, raised to the power `p`, averaged
    over the full extent of the last two dimensions, and the result is taken
    to the power `1 / p`. The last two dimensions of the output have size 1.
    """
    full_window = (x.size(-2), x.size(-1))
    powered = x.clamp(min=eps).pow(p)
    averaged = F.avg_pool2d(powered, full_window)
    return averaged.pow(1. / p)


class ShopeeNet(nn.Module):
    """Embedding network: backbone features -> pooling -> linear -> batch norm.

    Args:
        backbone: Model exposing ``reset_classifier(num_classes=...)``,
            ``num_features`` and ``forward_features(x)`` (all three are used
            below). Its classifier is reset with ``num_classes=0``.
        num_classes: Accepted for interface compatibility; not used by this
            class as written.
        fc_dim: Output size of the linear projection and of the batch norm.
        s: Accepted for interface compatibility; not used by this class.
        margin: Accepted for interface compatibility; not used by this class.
        p: Exponent passed to the module-level ``gem`` pooling function.
    """

    def __init__(self,
                 backbone,
                 num_classes,
                 fc_dim=512,
                 s=30, margin=0.5, p=3):
        super().__init__()

        self.backbone = backbone
        self.backbone.reset_classifier(num_classes=0)  # remove classifier

        self.fc = nn.Linear(self.backbone.num_features, fc_dim)
        self.bn = nn.BatchNorm1d(fc_dim)
        self._init_params()
        self.p = p

    def _init_params(self):
        """Initialise the projection (Xavier normal, zero bias) and batch norm (1 / 0)."""
        nn.init.xavier_normal_(self.fc.weight)
        nn.init.constant_(self.fc.bias, 0)
        nn.init.constant_(self.bn.weight, 1)
        nn.init.constant_(self.bn.bias, 0)

    def extract_feat(self, x):
        """Return the embedding for a batch of inputs `x`.

        If the backbone's ``forward_features`` returns a tuple, its first two
        elements are averaged and only batch norm is applied (the linear
        projection is skipped, so their width must already equal `fc_dim`).
        NOTE(review): presumably this is the two-headed output of a distilled
        backbone - confirm against the timm version in use.

        Otherwise the feature map is pooled with ``gem``, flattened per
        sample, projected by ``fc`` and batch-normalised.
        """
        batch_size = x.shape[0]
        x = self.backbone.forward_features(x)
        if isinstance(x, tuple):
            x = (x[0] + x[1]) / 2
            x = self.bn(x)
        else:
            x = gem(x, p=self.p).view(batch_size, -1)
            x = self.fc(x)
            x = self.bn(x)
        return x

    def forward(self, x, label):
        """Return ``(loss_module(feat, label), feat)`` for a batch.

        ``__init__`` does not create ``loss_module``; it must be attached to
        the instance by the caller before ``forward`` is used.

        Raises:
            AttributeError: If no ``loss_module`` has been attached. Use
                ``extract_feat`` when only embeddings are needed.
        """
        # Fail fast, before running the backbone, with an actionable message.
        loss_module = getattr(self, 'loss_module', None)
        if loss_module is None:
            raise AttributeError(
                "ShopeeNet has no 'loss_module': attach one before calling "
                "forward(), or call extract_feat() to get embeddings only"
            )
        feat = self.extract_feat(x)
        x = loss_module(feat, label)
        return x, feat


def batch_images_generator(path_to_folder, bs, preprocessor, resume_from_batch=None):
    """Yield preprocessed image batches from the ``*.jpg`` files in a folder.

    Args:
        path_to_folder: Directory that is searched (non-recursively) for
            ``*.jpg`` files. Zero-byte files are ignored.
        bs: Number of files per batch.
        preprocessor: Callable applied to each decoded image (a float tensor
            scaled by 1/255). When ``None``, the module-level ``transform``
            is used instead.
        resume_from_batch: Index into the sorted file list at which to start.
            NOTE(review): despite the name this is used as a file offset (the
            same unit as the yielded ``i``), not a batch number - confirm
            with callers.

    Yields:
        ``(i, batch, names)`` where ``i`` is the offset of the batch's first
        file in the file list, ``batch`` is the concatenated tensor on
        ``device`` and ``names`` lists the files that were actually loaded,
        in the same order as the rows of ``batch``.
    """
    import glob
    from torchvision.io import read_image

    # Honour the caller-supplied preprocessor; fall back to the global one.
    preprocess = preprocessor if preprocessor is not None else transform

    # Sort so that batch contents (and resume offsets) are reproducible;
    # glob order is otherwise arbitrary.
    images = sorted(glob.glob('*.jpg', root_dir=path_to_folder))
    images = [
        image for image in images
        if os.path.getsize(os.path.join(path_to_folder, image)) > 0
    ]
    print(len(images))

    start = resume_from_batch or 0
    for i in tqdm(range(start, len(images), bs)):
        loaded_names = []
        torch_batch = []
        for name in images[i:i + bs]:
            image_path = os.path.join(path_to_folder, name)
            try:
                image = read_image(image_path, mode='RGB').to(device).float() / 255
                torch_batch.append(preprocess(image)[None])
            except Exception as e:
                # Best effort: report and skip the unreadable file. The file
                # is not re-read here, since that could raise again.
                print(e)
                print(name)
                continue
            # Record the name only on success so names stay aligned with rows.
            loaded_names.append(name)

        if not torch_batch:
            # Every file in this batch failed; torch.cat([]) would raise.
            continue

        yield i, torch.cat(torch_batch).to(device), loaded_names

In [None]:
# Imported up front so a missing dependency fails before the expensive
# embedding pass rather than after it.
import duckdb

# Each checkpoint is indexed with 'params' and 'model' below.
checkpoint1 = torch.load('weights/top2/v45.pth')
checkpoint2 = torch.load('weights/top2/v34.pth')

params1 = checkpoint1['params']
params2 = checkpoint2['params']
params1['backbone'] = 'deit_base_distilled_patch16_384.fb_in1k'

# Preprocessing used by batch_images_generator (resize, centre crop, normalise).
# NOTE(review): the size comes from params1 although only model2 is run
# below - confirm that params2['test_size'] is the same.
transform = Compose([
    Resize(size=params1['test_size'] + 32, interpolation=Image.BICUBIC),
    CenterCrop((params1['test_size'], params1['test_size'])),
    Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# backbone = timm.create_model(model_name=params1['backbone'], pretrained=False)
# model1 = ShopeeNet(backbone, num_classes=0, fc_dim=params1['fc_dim'])
# model1 = model1.to('cuda')
# model1.load_state_dict(checkpoint1['model'], strict=False)
# model1.train(False)
# model1.p = params1['p_eval']

backbone = timm.create_model(model_name=params2['backbone'], pretrained=False)
model2 = ShopeeNet(backbone, num_classes=0, fc_dim=params2['fc_dim'])
model2 = model2.to('cuda')
# strict=False: checkpoint keys without a counterpart in ShopeeNet are ignored.
model2.load_state_dict(checkpoint2['model'], strict=False)
model2.train(False)
model2.p = params2['p_eval']

img_feats1 = []
img_feats2 = []

img_hs = []
img_ws = []
st_sizes = []

batch_size = 64
for t in ['train', 'test']:
    output_dir = f'avito/images/{t}/parquets'
    # to_parquet does not create missing directories.
    os.makedirs(output_dir, exist_ok=True)
    images_dir = f'avito/images/{t}/images/'

    for i, batch, fname in batch_images_generator(images_dir, bs=batch_size, preprocessor=None):

        with torch.no_grad():
            # feats_minibatch1 = model1.extract_feat(batch)
            # img_feats1.append(feats_minibatch1.cpu().numpy())
            feats_minibatch2 = model2.extract_feat(batch).tolist()

        df = pd.DataFrame({
            "filename": fname,
            "embedding": feats_minibatch2  # one list of floats per image
        })

        # i is the offset of the batch's first file, so i // batch_size is
        # the batch number.
        df.to_parquet(
            os.path.join(output_dir, f"batch_{i // batch_size}.parquet"),
            index=False
        )

    # Merge only the per-batch files: a plain '*.parquet' glob would also
    # match final_embeddings.parquet from a previous run and duplicate rows.
    duckdb.sql(rf"""
        COPY (
            SELECT * FROM '{output_dir}/batch_*.parquet'
        )
        TO '{output_dir}/final_embeddings.parquet' (FORMAT PARQUET)
    """)