In [None]:
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '4,5,6,7'

import sys

import numpy as np
import torch

from tqdm.auto import tqdm

In [None]:
from optimizers import Adan, Lookahead, AGC

In [None]:
LR = 1e-3
BATCH_SIZE = 512

## load

In [None]:
# pip install scikit-learn mega.py

sys.path.append(os.path.abspath('../kcg-ml-vae-test/'))
from utilities.utils import read_embedding_data, read_msg_pack

In [None]:
import glob

paths = sorted(glob.glob('/workspace/kk-digital/kcg-ml-image-pipeline/output/environmental/ranking_v2/embeddings/*_embedding.msgpack'))

pos_embs = []
for path in tqdm(paths):
    pos_emb, neg_emb = read_embedding_data(path)
    pos_embs.append(pos_emb)
pos_embs = np.concatenate(pos_embs, axis=0)

In [None]:
from sklearn.model_selection import train_test_split

In [None]:
train_data, val_data = train_test_split(pos_embs, test_size=0.2, random_state=42)

In [233]:
class Model(torch.nn.Module):
    def __init__(self, input_dim, seq_len, hidden_dim):
        super().__init__()
        
        self.encoder = torch.nn.Sequential(
            torch.nn.Conv1d(input_dim, 384, kernel_size=1, bias=False, padding='same'),
            torch.nn.BatchNorm1d(384),
            torch.nn.ReLU(),
            torch.nn.Conv1d(384, 128, kernel_size=1, bias=False, padding='same'),
            torch.nn.BatchNorm1d(128),
            torch.nn.ReLU(),
            torch.nn.Conv1d(128, 64, kernel_size=1, bias=False, padding='same'),
            torch.nn.BatchNorm1d(64),
            torch.nn.ReLU(),
            torch.nn.Conv1d(64, 32, kernel_size=1, bias=False, padding='same'),
            torch.nn.BatchNorm1d(32),
            torch.nn.ReLU()
        )
        
        self.encoder2 = torch.nn.Sequential(
            torch.nn.Linear(seq_len * 32, hidden_dim, bias=False),
            torch.nn.BatchNorm1d(hidden_dim),
            torch.nn.ReLU()
        )
        
        self.dropout = torch.nn.Dropout(0.5)
        
        self.decoder2 = torch.nn.Sequential(
            torch.nn.Linear(hidden_dim, seq_len * 32, bias=False),
            torch.nn.BatchNorm1d(seq_len * 32),
            torch.nn.ReLU()
        )
        
        self.decoder = torch.nn.Sequential(
            torch.nn.Conv1d(32, 64, kernel_size=1, bias=False, padding='same'),
            torch.nn.BatchNorm1d(64),
            torch.nn.ReLU(),
            torch.nn.Conv1d(64, 128, kernel_size=1, bias=False, padding='same'),
            torch.nn.BatchNorm1d(128),
            torch.nn.ReLU(),
            torch.nn.Conv1d(128, 384, kernel_size=1, bias=False, padding='same'),
            torch.nn.BatchNorm1d(384),
            torch.nn.ReLU(),
            torch.nn.Conv1d(384, input_dim, kernel_size=1, bias=False, padding='same')
        )

    def forward(self, x):

        z = self.encoder(x)
        seq_shape = z.shape
        z = z.view(z.shape[0], -1)
        z = self.dropout(z)
        z = self.encoder2(z)
        z = self.dropout(z)
        z = self.decoder2(z)
        z = z.view(*seq_shape)
        y = self.decoder(z)

        return y

In [234]:
model = Model(768, 77, 1024)

In [235]:
model.encoder.load_state_dict(torch.load(os.path.join('weight/002/encoder.pt')))
model.decoder.load_state_dict(torch.load(os.path.join('weight/002/decoder.pt')))

<All keys matched successfully>

In [236]:
model.encoder2.load_state_dict(torch.load(os.path.join('weight/002/encoder2.pt')))
model.decoder2.load_state_dict(torch.load(os.path.join('weight/002/decoder2.pt')))

<All keys matched successfully>

In [237]:
parameters = list(model.parameters())
# parameters = list(model.encoder2.parameters()) + list(model.decoder2.parameters())

optimizer = Adan(parameters, lr=LR, weight_decay=1e-3)
optimizer = Lookahead(optimizer)
optimizer = AGC(optimizer)
warmup = torch.optim.lr_scheduler.LambdaLR(optimizer, [lambda step: step / 100. if step < 100 else 1.])

In [238]:
model = torch.nn.DataParallel(model.cuda())

In [239]:
train_dataset = torch.utils.data.TensorDataset(torch.tensor(train_data).permute(0, 2, 1))
val_dataset = torch.utils.data.TensorDataset(torch.tensor(val_data).permute(0, 2, 1))

In [240]:
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=BATCH_SIZE, drop_last=True, shuffle=True)
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=BATCH_SIZE)

In [261]:
scaler = torch.cuda.amp.GradScaler()
    
mses, coss = list(), list()

for epoch in tqdm(range(1000)):
    
    model.train()
    
    for (x,) in train_loader:
        
        x = x.half().cuda()

        optimizer.zero_grad()

        with torch.cuda.amp.autocast(True):

            y = model(x)

        # backward

        mse = torch.nn.functional.mse_loss(x, y)
        
        cos = torch.nn.functional.cosine_similarity(x, y).mean()
        
        loss = mse - cos

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        warmup.step()
        
        mses.append(mse.detach().cpu().numpy())
        coss.append(cos.detach().cpu().numpy())
        
    if (epoch + 1) % 100 == 0:
        
        model.eval()

        val_mses, val_coss = list(), list()

        with torch.no_grad():

            for (x,) in val_loader:

                x = x.half().cuda()

                with torch.cuda.amp.autocast(True):

                    y = model(x)

                mse = torch.nn.functional.mse_loss(x, y)

                cos = torch.nn.functional.cosine_similarity(x, y).mean()

                val_mses.append(mse.detach().cpu().numpy())
                val_coss.append(cos.detach().cpu().numpy())

        print(f'{np.mean(mses):.4f} {np.mean(coss):.4f} {np.mean(val_mses):.4f} {np.mean(val_coss):.4f}')
    
        mses, coss = list(), list()

  0%|          | 0/1000 [00:00<?, ?it/s]

0.3718 0.7817 0.4585 0.7236
0.3713 0.7822 0.4585 0.7236
0.3708 0.7822 0.4597 0.7231
0.3706 0.7827 0.4575 0.7246
0.3701 0.7827 0.4590 0.7236
0.3699 0.7832 0.4595 0.7236
0.3694 0.7832 0.4590 0.7236
0.3691 0.7837 0.4602 0.7227
0.3687 0.7837 0.4600 0.7231
0.3684 0.7842 0.4609 0.7227


In [259]:
torch.save(model.module.encoder.state_dict(), os.path.join('weight/002/encoder.pt'))
torch.save(model.module.decoder.state_dict(), os.path.join('weight/002/decoder.pt'))

In [260]:
torch.save(model.module.encoder2.state_dict(), os.path.join('weight/002/encoder2.pt'))
torch.save(model.module.decoder2.state_dict(), os.path.join('weight/002/decoder2.pt'))

In [138]:
for i in model.module.encoder2:
    model.module.encoder.add_module(str(len(model.module.encoder)), i)

In [139]:
for i in model.module.decoder:
    model.module.decoder2.add_module(str(len(model.module.decoder2)), i)

In [209]:
from sklearn.preprocessing import normalize

In [253]:
X = x.permute(0, 2, 1).detach().cpu().numpy()

In [254]:
X = X.reshape(-1, 768)
Y, l2_norm = normalize(X, norm='l2', axis=1, return_norm=True)

In [255]:
l2_norm.shape

(4543,)

In [256]:
torch.nn.functional.mse_loss(x, y)

tensor(0.4646, device='cuda:0', dtype=torch.float16)

In [257]:
norm = torch.tensor(l2_norm).cuda().view(x.shape[0], 1, x.shape[-1])

In [258]:
torch.nn.functional.mse_loss(x / norm, y / norm, reduction='sum') / x.shape[0]

tensor(35.6875, device='cuda:0', dtype=torch.float16)