In [1]:
# Enable IPython's autoreload extension; mode 2 re-imports all modules before
# each cell executes, so edits to imported project code are picked up without a kernel restart.
%load_ext autoreload
%autoreload 2

In [2]:
import torch
from torch import nn
import torch.nn.functional as F
from data import load_parquets_from_zip
import polars as pl
import numpy as np
import random
from torch.utils.data import TensorDataset, DataLoader
from tqdm.auto import tqdm
import os

In [3]:
# Seed every random number generator used in this notebook with the same value.
SEED = 42
for seed_fn in (torch.manual_seed, random.seed, np.random.seed):
    seed_fn(SEED)

In [4]:
# Open the dataset archive and pick the image-embedding table by its key.
archive_path = 'dataset/Ekstra_Bladet_image_embeddings.zip'
table_key = 'Ekstra_Bladet_image_embeddings/image_embeddings'
parquet_tables = load_parquets_from_zip(archive_path)
imgs = parquet_tables[table_key]

In [5]:
# Pull the `image_embedding` column into a Python list, stack it into one
# NumPy array, and wrap that array as a tensor.
embedding_rows = list(imgs['image_embedding'])
imgs_embs = torch.from_numpy(np.asarray(embedding_rows))

In [6]:
# Training loader: shuffled mini-batches of 512 rows; each batch is a
# one-element sequence holding the embedding tensor.
train_dataset = TensorDataset(imgs_embs)
dl = DataLoader(dataset=train_dataset, shuffle=True, batch_size=512)

In [7]:
class Normalize(nn.Module):
    """Parameter-free layer that L2-normalizes its input along the last dimension."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        """Return ``F.normalize(x, dim=-1)``: each last-axis vector rescaled to unit norm."""
        unit_vectors = F.normalize(x, dim=-1)
        return unit_vectors

class EncodeDecoder(nn.Module):
    """Bias-free linear autoencoder with an L2-normalized bottleneck.

    The encoder chains ``nn.Linear`` layers of sizes
    ``in_size -> hidden[0] -> ... -> hidden[-1]`` (no activation in between)
    and ends with ``Normalize``. The decoder mirrors the same sizes in reverse,
    mapping the code back to ``in_size``.

    Args:
        in_size: Width of the input (and of the reconstruction).
        hidden: Output widths of the successive encoder layers; the last entry
            is the code size. Any sequence of ints (list or tuple) is accepted.
    """

    def __init__(self, in_size=1024, hidden=(512, 128)):
        super().__init__()
        # Tuple default + fresh list here: no mutable default argument is shared
        # between calls, and callers may pass either a list or a tuple.
        sizes = [in_size, *hidden]

        # Layers are created encoder-first, then decoder, so parameter
        # initialization order and state_dict keys (`encoder.N.weight`,
        # `decoder.N.weight`) are stable.
        self.encoder = nn.Sequential(
            *(nn.Linear(n_in, n_out, bias=False) for n_in, n_out in zip(sizes, sizes[1:])),
            Normalize(),
        )
        mirrored = sizes[::-1]
        self.decoder = nn.Sequential(
            *(nn.Linear(n_in, n_out, bias=False) for n_in, n_out in zip(mirrored, mirrored[1:]))
        )

    def forward(self, x):
        """Encode ``x`` to the normalized code and decode it back to input width."""
        return self.decoder(self.encoder(x))

In [8]:
# Use the GPU when one is available, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# Autoencoder with its default layer sizes, moved to the chosen device.
model = EncodeDecoder()
model = model.to(device)

# Mean-squared-error criterion and Adam with its default hyperparameters.
mse = nn.MSELoss()
opt = torch.optim.Adam(model.parameters())

In [9]:
model_path = 'preprocess/image_encoder.pth'
n_epochs = 99  # same number of epochs as the original `range(1, 100)`

if os.path.exists(model_path):
    # Reuse cached weights. `map_location` remaps the stored tensors onto the
    # current device, so a checkpoint written on a GPU machine still loads when
    # only the CPU is available.
    model.load_state_dict(torch.load(model_path, map_location=device))
else:
    # Make sure the output directory exists *before* training, so a missing
    # folder cannot make `torch.save` fail after the last epoch.
    os.makedirs(os.path.dirname(model_path) or '.', exist_ok=True)
    model.train()
    for e in range(1, n_epochs + 1):
        c_loss = 0  # running sum of batch losses in this epoch
        c_step = 0  # number of batches processed in this epoch
        with tqdm(dl, leave=False) as dll:
            for x in dll:
                opt.zero_grad()
                x = x[0].to(device)  # each batch is a one-element sequence
                x_pred = model(x)
                loss = mse(x_pred, x)  # reconstruction error against the input itself
                loss.backward()
                opt.step()
                c_loss += loss.item()
                c_step += 1
                dll.set_postfix(loss=c_loss / c_step)
        # Mean batch loss over the epoch.
        print(f'Epoch {e}: {c_loss / c_step }')
    torch.save(model.state_dict(), model_path)
# Switch to inference mode; as the cell's last expression this also displays the module.
model.eval()

  0%|          | 0/201 [00:00<?, ?it/s]

Epoch 1: 4.091748173857072e-05


  0%|          | 0/201 [00:00<?, ?it/s]

Epoch 2: 7.995236991116339e-07


  0%|          | 0/201 [00:00<?, ?it/s]

Epoch 3: 3.9625150006104017e-07


  0%|          | 0/201 [00:00<?, ?it/s]

Epoch 4: 3.477615116163853e-07


  0%|          | 0/201 [00:00<?, ?it/s]

Epoch 5: 3.1993745403328534e-07


  0%|          | 0/201 [00:00<?, ?it/s]

Epoch 6: 2.7523079824960626e-07


  0%|          | 0/201 [00:00<?, ?it/s]

Epoch 7: 2.420363729095978e-07


  0%|          | 0/201 [00:00<?, ?it/s]

Epoch 8: 2.3389708515614034e-07


  0%|          | 0/201 [00:00<?, ?it/s]

Epoch 9: 2.439666356386891e-07


  0%|          | 0/201 [00:00<?, ?it/s]

Epoch 10: 2.081311887015762e-07


  0%|          | 0/201 [00:00<?, ?it/s]

Epoch 11: 2.069445361652542e-07


  0%|          | 0/201 [00:00<?, ?it/s]

Epoch 12: 2.0407342930608517e-07


  0%|          | 0/201 [00:00<?, ?it/s]

Epoch 13: 2.0928900458565028e-07


  0%|          | 0/201 [00:00<?, ?it/s]

Epoch 14: 2.0021308279155078e-07


  0%|          | 0/201 [00:00<?, ?it/s]

Epoch 15: 1.989583797745554e-07


  0%|          | 0/201 [00:00<?, ?it/s]

Epoch 16: 1.9622625578150875e-07


  0%|          | 0/201 [00:00<?, ?it/s]

Epoch 17: 1.6951379599950133e-07


  0%|          | 0/201 [00:00<?, ?it/s]

Epoch 18: 1.8010173201997279e-07


  0%|          | 0/201 [00:00<?, ?it/s]

Epoch 19: 1.8918132410313112e-07


  0%|          | 0/201 [00:00<?, ?it/s]

Epoch 20: 2.2482629878315075e-07


  0%|          | 0/201 [00:00<?, ?it/s]

Epoch 21: 1.3918191303951944e-07


  0%|          | 0/201 [00:00<?, ?it/s]

Epoch 22: 1.823546212030852e-07


  0%|          | 0/201 [00:00<?, ?it/s]

Epoch 23: 1.6739141789207792e-07


  0%|          | 0/201 [00:00<?, ?it/s]

Epoch 24: 1.645557400723789e-07


  0%|          | 0/201 [00:00<?, ?it/s]

Epoch 25: 1.9467326215173934e-07


  0%|          | 0/201 [00:00<?, ?it/s]

Epoch 26: 1.388733783447669e-07


  0%|          | 0/201 [00:00<?, ?it/s]

Epoch 27: 1.8143369941870432e-07


  0%|          | 0/201 [00:00<?, ?it/s]

Epoch 28: 1.4194491990649147e-07


  0%|          | 0/201 [00:00<?, ?it/s]

Epoch 29: 1.6967185548871307e-07


  0%|          | 0/201 [00:00<?, ?it/s]

Epoch 30: 1.7822093717009396e-07


  0%|          | 0/201 [00:00<?, ?it/s]

Epoch 31: 1.3769312182427905e-07


  0%|          | 0/201 [00:00<?, ?it/s]

Epoch 32: 1.632827181677937e-07


  0%|          | 0/201 [00:00<?, ?it/s]

Epoch 33: 1.3628685138090177e-07


  0%|          | 0/201 [00:00<?, ?it/s]

Epoch 34: 1.2162095519961123e-07


  0%|          | 0/201 [00:00<?, ?it/s]

Epoch 35: 1.6197616730692492e-07


  0%|          | 0/201 [00:00<?, ?it/s]

Epoch 36: 1.370913400975625e-07


  0%|          | 0/201 [00:00<?, ?it/s]

Epoch 37: 1.4043852989100096e-07


  0%|          | 0/201 [00:00<?, ?it/s]

Epoch 38: 1.2196300789202804e-07


  0%|          | 0/201 [00:00<?, ?it/s]

Epoch 39: 1.5565381616016426e-07


  0%|          | 0/201 [00:00<?, ?it/s]

Epoch 40: 1.2486973418783998e-07


  0%|          | 0/201 [00:00<?, ?it/s]

Epoch 41: 1.1065315546762006e-07


  0%|          | 0/201 [00:00<?, ?it/s]

Epoch 42: 1.275914777878241e-07


  0%|          | 0/201 [00:00<?, ?it/s]

Epoch 43: 1.1750814330645735e-07


  0%|          | 0/201 [00:00<?, ?it/s]

Epoch 44: 1.0892606762616704e-07


  0%|          | 0/201 [00:00<?, ?it/s]

Epoch 45: 1.3332097279119205e-07


  0%|          | 0/201 [00:00<?, ?it/s]

Epoch 46: 1.2566822366138627e-07


  0%|          | 0/201 [00:00<?, ?it/s]

Epoch 47: 9.476592095237929e-08


  0%|          | 0/201 [00:00<?, ?it/s]

Epoch 48: 1.446510878087512e-07


  0%|          | 0/201 [00:00<?, ?it/s]

Epoch 49: 7.742726856788602e-08


  0%|          | 0/201 [00:00<?, ?it/s]

Epoch 50: 1.1997928629597396e-07


  0%|          | 0/201 [00:00<?, ?it/s]

Epoch 51: 1.0949965423942015e-07


  0%|          | 0/201 [00:00<?, ?it/s]

Epoch 52: 1.2610222899720205e-07


  0%|          | 0/201 [00:00<?, ?it/s]

Epoch 53: 8.352211782335754e-08


  0%|          | 0/201 [00:00<?, ?it/s]

Epoch 54: 1.2593050329665467e-07


  0%|          | 0/201 [00:00<?, ?it/s]

Epoch 55: 9.3266149518746e-08


  0%|          | 0/201 [00:00<?, ?it/s]

Epoch 56: 1.2293961158767427e-07


  0%|          | 0/201 [00:00<?, ?it/s]

Epoch 57: 7.75493916689042e-08


  0%|          | 0/201 [00:00<?, ?it/s]

Epoch 58: 1.1997264374655577e-07


  0%|          | 0/201 [00:00<?, ?it/s]

Epoch 59: 6.663853371821972e-08


  0%|          | 0/201 [00:00<?, ?it/s]

Epoch 60: 1.0758061486866885e-07


  0%|          | 0/201 [00:00<?, ?it/s]

Epoch 61: 9.628901752440602e-08


  0%|          | 0/201 [00:00<?, ?it/s]

Epoch 62: 8.816797459948207e-08


  0%|          | 0/201 [00:00<?, ?it/s]

Epoch 63: 1.0112421633511895e-07


  0%|          | 0/201 [00:00<?, ?it/s]

Epoch 64: 8.418402383815408e-08


  0%|          | 0/201 [00:00<?, ?it/s]

Epoch 65: 1.060568807524511e-07


  0%|          | 0/201 [00:00<?, ?it/s]

Epoch 66: 7.659108547386495e-08


  0%|          | 0/201 [00:00<?, ?it/s]

Epoch 67: 1.0121982438506474e-07


  0%|          | 0/201 [00:00<?, ?it/s]

Epoch 68: 1.0867896115379058e-07


  0%|          | 0/201 [00:00<?, ?it/s]

Epoch 69: 6.414336073898586e-08


  0%|          | 0/201 [00:00<?, ?it/s]

Epoch 70: 1.0026520965396488e-07


  0%|          | 0/201 [00:00<?, ?it/s]

Epoch 71: 6.836760045654222e-08


  0%|          | 0/201 [00:00<?, ?it/s]

Epoch 72: 1.0488960828692216e-07


  0%|          | 0/201 [00:00<?, ?it/s]

Epoch 73: 9.228169482077501e-08


  0%|          | 0/201 [00:00<?, ?it/s]

Epoch 74: 6.305753320938226e-08


  0%|          | 0/201 [00:00<?, ?it/s]

Epoch 75: 1.0495819658835085e-07


  0%|          | 0/201 [00:00<?, ?it/s]

Epoch 76: 7.555385556769924e-08


  0%|          | 0/201 [00:00<?, ?it/s]

Epoch 77: 1.0566449512155383e-07


  0%|          | 0/201 [00:00<?, ?it/s]

Epoch 78: 5.756485276734154e-08


  0%|          | 0/201 [00:00<?, ?it/s]

Epoch 79: 9.172783457068508e-08


  0%|          | 0/201 [00:00<?, ?it/s]

Epoch 80: 8.262504614826053e-08


  0%|          | 0/201 [00:00<?, ?it/s]

Epoch 81: 7.96744842970309e-08


  0%|          | 0/201 [00:00<?, ?it/s]

Epoch 82: 7.444707772088211e-08


  0%|          | 0/201 [00:00<?, ?it/s]

Epoch 83: 6.907513982360431e-08


  0%|          | 0/201 [00:00<?, ?it/s]

Epoch 84: 1.4089717741449904e-07


  0%|          | 0/201 [00:00<?, ?it/s]

Epoch 85: 4.7790262290356115e-08


  0%|          | 0/201 [00:00<?, ?it/s]

Epoch 86: 7.662198146278311e-08


  0%|          | 0/201 [00:00<?, ?it/s]

Epoch 87: 7.511229387501654e-08


  0%|          | 0/201 [00:00<?, ?it/s]

Epoch 88: 8.381989076509564e-08


  0%|          | 0/201 [00:00<?, ?it/s]

Epoch 89: 7.139282817109266e-08


  0%|          | 0/201 [00:00<?, ?it/s]

Epoch 90: 7.262293235765273e-08


  0%|          | 0/201 [00:00<?, ?it/s]

Epoch 91: 8.369624922364603e-08


  0%|          | 0/201 [00:00<?, ?it/s]

Epoch 92: 6.93054849993736e-08


  0%|          | 0/201 [00:00<?, ?it/s]

Epoch 93: 6.846569429252605e-08


  0%|          | 0/201 [00:00<?, ?it/s]

Epoch 94: 7.354550079831097e-08


  0%|          | 0/201 [00:00<?, ?it/s]

Epoch 95: 9.342550440513589e-08


  0%|          | 0/201 [00:00<?, ?it/s]

Epoch 96: 8.112770683878106e-08


  0%|          | 0/201 [00:00<?, ?it/s]

Epoch 97: 4.8934624047728567e-08


  0%|          | 0/201 [00:00<?, ?it/s]

Epoch 98: 8.286820037690844e-08


  0%|          | 0/201 [00:00<?, ?it/s]

Epoch 99: 6.589718893718922e-08


EncodeDecoder(
  (encoder): Sequential(
    (0): Linear(in_features=1024, out_features=512, bias=False)
    (1): Linear(in_features=512, out_features=128, bias=False)
    (2): Normalize()
  )
  (decoder): Sequential(
    (0): Linear(in_features=128, out_features=512, bias=False)
    (1): Linear(in_features=512, out_features=1024, bias=False)
  )
)

In [10]:
# Draw a single batch from the loader and move its tensor to the model's device.
sample_batch = next(iter(dl))
x = sample_batch[0].to(device)

In [11]:
# Display the sampled batch of input embeddings.
x

tensor([[-0.0325,  0.0286, -0.0039,  ...,  0.0428, -0.0180, -0.0141],
        [-0.0194, -0.0154, -0.0571,  ..., -0.0247,  0.0367,  0.0314],
        [-0.0080,  0.0134,  0.0466,  ...,  0.0630, -0.0346, -0.0435],
        ...,
        [-0.0051,  0.0427,  0.0274,  ...,  0.0329, -0.0334, -0.0161],
        [-0.0104,  0.0157,  0.0363,  ...,  0.0585, -0.0177, -0.0481],
        [-0.0461, -0.0466, -0.0732,  ..., -0.0009,  0.0293,  0.0025]],
       device='cuda:0')

In [12]:
# Display the autoencoder's reconstruction of the sampled batch. The call is
# made outside `torch.no_grad()`, so the displayed tensor carries a grad_fn.
model(x)

tensor([[-0.0324,  0.0286, -0.0039,  ...,  0.0429, -0.0181, -0.0143],
        [-0.0196, -0.0156, -0.0572,  ..., -0.0245,  0.0363,  0.0307],
        [-0.0077,  0.0134,  0.0467,  ...,  0.0633, -0.0348, -0.0433],
        ...,
        [-0.0050,  0.0430,  0.0274,  ...,  0.0330, -0.0339, -0.0162],
        [-0.0105,  0.0161,  0.0360,  ...,  0.0584, -0.0178, -0.0481],
        [-0.0465, -0.0460, -0.0730,  ..., -0.0010,  0.0293,  0.0022]],
       device='cuda:0', grad_fn=<MmBackward0>)

In [13]:
# Standard deviation of each embedding dimension across the sampled batch.
x.std(dim=0)

tensor([0.0186, 0.0231, 0.0307,  ..., 0.0282, 0.0268, 0.0226], device='cuda:0')

In [14]:
# Mean absolute reconstruction error per embedding dimension on the sampled batch.
(x - model(x)).abs().mean(dim=0)

tensor([0.0002, 0.0002, 0.0001,  ..., 0.0002, 0.0001, 0.0002], device='cuda:0',
       grad_fn=<MeanBackward1>)

In [15]:
# Spread (standard deviation) of the absolute reconstruction error per dimension.
(x - model(x)).abs().std(dim=0)

tensor([0.0001, 0.0002, 0.0001,  ..., 0.0001, 0.0001, 0.0002], device='cuda:0',
       grad_fn=<StdBackward0>)

In [16]:
# Rebuild the loader without shuffling so the encoded rows keep the dataset order.
dl = DataLoader(TensorDataset(imgs_embs), shuffle=False, batch_size=512)

# Run only the encoder over every batch, without tracking gradients, and keep
# each batch's codes as a NumPy array on the host.
with torch.no_grad():
    embs = [model.encoder(batch.to(device)).cpu().numpy() for (batch,) in dl]

In [17]:
# Stack the per-batch encoder outputs row-wise into one 2-D array.
# For the list of 2-D batches `np.vstack` gives the same result as
# `np.concatenate(..., axis=0)`, but it also leaves an already-stacked 2-D
# array unchanged, so re-running this cell does not flatten `embs` to 1-D.
embs = np.vstack(embs)

In [18]:
# Display the stacked encoder outputs.
embs

array([[ 0.11453101,  0.00137664, -0.01590084, ..., -0.02036735,
        -0.05598754,  0.0513007 ],
       [-0.0300234 ,  0.00979323, -0.00183117, ...,  0.0564303 ,
        -0.07377557,  0.09771489],
       [ 0.2675379 , -0.06965413, -0.18188432, ..., -0.1330114 ,
         0.00300317, -0.02472829],
       ...,
       [-0.05685404,  0.05275042,  0.06841292, ..., -0.02606313,
         0.0475052 ,  0.02816776],
       [ 0.0441104 ,  0.01659502, -0.03042564, ...,  0.06748616,
         0.01042796, -0.05028164],
       [-0.02042744,  0.08351305, -0.08589312, ...,  0.10827801,
        -0.06187802, -0.02863881]], dtype=float32)

In [19]:
# Wrap the encoded array in a polars frame and attach the article ids.
# Row order lines up with `imgs` because the inference loader was built with shuffle=False.
# NOTE(review): relies on polars accepting the 2-D array under a single
# `embeddings` column name — confirm against the installed polars version.
codes_frame = pl.DataFrame(data=embs, schema=['embeddings'])
codes_with_ids = codes_frame.with_columns(imgs['article_id'])
embs = codes_with_ids.select('article_id', 'embeddings')

In [20]:
# Persist the (article_id, embeddings) table for downstream use.
output_path = 'preprocess/image_embs.parquet'
embs.write_parquet(output_path)