In [84]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [85]:
from data_process import *
from tqdm.auto import tqdm
from torch.utils.data import Dataset
from functools import partial
from itertools import chain
import polars as pl
import numpy as np
import os
from collections import defaultdict, Counter
from typing import Tuple, Dict, List, Any, Callable, Optional, Union, Iterable
from datetime import datetime
from bisect import bisect
import math
import sentence_transformers as st

base_dir = 'preprocess'
if not os.path.exists(base_dir):
    os.makedirs(base_dir)

In [86]:
# Processes the article DataFrame by embedding categorical features, encoding multi-value columns, 
# and generating lookup dictionaries for article IDs and publication dates.
def preprocess_article(
    article_df: pl.DataFrame, 
    min_reps: int = 100, 
    image_ids: Optional[Dict[int, int]] = None, 
    model_name: str = 'paraphrase-multilingual-mpnet-base-v2'
) -> Tuple[pl.DataFrame, pl.DataFrame, pl.DataFrame, pl.DataFrame, Dict[int, int], Dict[int, Any]]:
    # Initialize the SentenceTransformer model
    model = st.SentenceTransformer(model_name)
    ops = []

    def encode_and_create_lookup(df: pl.DataFrame, column: str, embedding_model, schema: str):
        """Helper function to encode and create a lookup for a specific column."""
        unique_values = sorted(set(df[column].drop_nulls().to_list()))
        mapping = {val: idx for idx, val in enumerate(unique_values)}
        embeddings = embedding_model.encode(unique_values)
        encoded_df = pl.DataFrame(data={column: unique_values, 'embeddings': embeddings})
        return mapping, encoded_df

    # 1. Encode and link category_str
    categories_map, categories_df = encode_and_create_lookup(article_df, 'category_str', model, 'category_link')
    ops.append(pl.col('category_str').replace(categories_map).cast(pl.Int64).alias('category_link'))

    # 2. Encode and link ner_clusters
    ner_values = sorted(set(val for sublist in article_df['ner_clusters'].drop_nulls() for val in sublist))
    ner_map = {val: idx for idx, val in enumerate(ner_values)}
    embeddings = model.encode(ner_values)
    ner_df = pl.DataFrame(data={'ners': ner_values, 'embeddings': embeddings})
    ops.append(pl.col('ner_clusters').map_elements(
        lambda x: [ner_map[n] for n in x] if x is not None else [], 
        return_dtype=pl.List(pl.Int64), 
        skip_nulls=False
    ).alias('ner_clusters_link'))

    # 3. Encode and link topics
    topics_map, topics_df = encode_and_create_lookup(article_df, 'topics', model, 'topics_link')
    ops.append(pl.col('topics').map_elements(
        lambda x: [topics_map[n] for n in x] if x is not None else [], 
        return_dtype=pl.List(pl.Int64), 
        skip_nulls=False
    ).alias('topics_link'))

    # Apply encoding operations to article_df
    article_df = article_df.with_columns(ops)

    # Clean and preprocess ner_clusters
    article_df = article_df.with_columns(
        pl.col('ner_clusters').map_elements(
            lambda x: [n.lower()[:3] for n in x] if x is not None else [], 
            return_dtype=pl.List(pl.String), 
            skip_nulls=False
        )
    )

    # Create mappings for article_id to published_time and article_id to index
    articleid_date = dict(article_df.select('article_id', 'published_time').rows())
    articleid_index = {article_id: idx for idx, article_id in enumerate(article_df['article_id'])}

    # Process specific features in the DataFrame
    ops = [
        pl.col('premium').map_elements(lambda x: int(bool(x)), return_dtype=pl.Int64),
        (pl.col('published_time').dt.weekday() - 1).alias('pub_weekday'),
        pl.col('published_time').dt.hour().alias('pub_hour')
    ]

    # List of features to be processed
    features = [
        ('article_type', False, True),
        ('subcategory', True, False),
        ('sentiment_label', False, False),
        ('category', False, True),
        ('ner_clusters', True, False),
        ('topics', True, False)
    ]

    for feature, is_list, allow_unknown in features:
        map_feature, valid_values = get_map_for_feature(article_df, feature, min_reps, is_list_column=is_list, unknown=allow_unknown)
        
        if is_list:
            valid_set = set(valid_values)
            ops.append(pl.col(feature).map_elements(
                lambda x: [map_feature(v) for v in x if v in valid_set] if x is not None else [], 
                return_dtype=pl.List(pl.Int64), 
                skip_nulls=False
            ))
        else:
            ops.append(pl.col(feature).map_elements(map_feature, return_dtype=pl.Int64))
        
        print(f'Feature {feature} encoded. Number of values: {(len(valid_values) + 1) if allow_unknown else len(valid_values)}')

    if image_ids is not None:
        ops.append(pl.col('image_ids').map_elements(
            lambda x: [image_ids[v] for v in x if v in image_ids] if x is not None else [], 
            return_dtype=pl.List(pl.Int64), 
            skip_nulls=False
        ))

    article_df = article_df.with_columns(ops)

    return article_df, categories_df, ner_df, topics_df, articleid_index, articleid_date

In [88]:
# Computes the delta time between article publication and user impression for each article in view.
# Also, computes the total number of articles in view for each impression.
def compute_behaviors_dates(behavior_df: pl.DataFrame, article_publish: Dict[int, datetime]) -> pl.DataFrame:
    # Calculate the time delta for each article in view relative to its publication time
    behavior_df = behavior_df.with_columns([
        pl.struct(['impression_time', 'article_ids_inview']).map_elements(
            lambda x: [
                int((x['impression_time'] - article_publish[a][0]).total_seconds()) 
                for a in x['article_ids_inview']
            ] if x['article_ids_inview'] else [],
            return_dtype=pl.List(pl.Int64)
        ).alias('article_delta_time'),
        
        # Count the number of articles in view
        pl.col('article_ids_inview').map_elements(
            lambda x: len(x) if x else 0, 
            return_dtype=pl.Int64
        ).alias('number_articles')
    ])
    
    return behavior_df

In [89]:
def compute_history_dates(history_df: pl.DataFrame, article_publish: Dict[int, datetime], sorted_history: bool=False) -> pl.DataFrame:
    if sorted_history:
        history_df = history_df.with_columns(order=pl.col('impression_time_fixed').map_elements(ids_sort, return_dtype=pl.List(pl.Int64)))
        history_df = history_df.with_columns(
                        impression_time_fixed=pl.struct('impression_time_fixed', 'order').
                        map_elements(lambda x: sort_ids(x['impression_time_fixed'], x['order']), 
                                    return_dtype=pl.List(pl.Datetime)),
                        scroll_percentage_fixed=pl.struct('scroll_percentage_fixed', 'order').
                        map_elements(lambda x: sort_ids(x['scroll_percentage_fixed'], x['order']), 
                                    return_dtype=pl.List(pl.Float64)),
                        article_id_fixed=pl.struct('article_id_fixed', 'order').
                        map_elements(lambda x: sort_ids(x['article_id_fixed'], x['order']), 
                                    return_dtype=pl.List(pl.Int64)),
                        read_time_fixed=pl.struct('read_time_fixed', 'order').
                        map_elements(lambda x: sort_ids(x['read_time_fixed'], x['order']), 
                                    return_dtype=pl.List(pl.Float64))). \
                        drop('order')
    history_df = history_df.with_columns(
        impression_weekday=pl.col('impression_time_fixed'). 
            map_elements(lambda x: [d.weekday() for d in x], return_dtype=pl.List(pl.Int64)),
        impression_hour=pl.col('impression_time_fixed'). 
            map_elements(lambda x: [d.hour for d in x], return_dtype=pl.List(pl.Int64)),
        article_delta_time=pl.struct(('impression_time_fixed', 'article_id_fixed')). 
            map_elements(lambda x:[ int((d - article_publish[a][0]).total_seconds()) for d, a in zip(x['impression_time_fixed'], x['article_id_fixed'])], 
                         return_dtype=pl.List(pl.Int64)),
        number_articles=pl.col('article_id_fixed').map_elements(len, return_dtype=pl.Int64)
    )
    return history_df

In [90]:
# Preprocesses the behavior and history DataFrames by computing time-based features, 
# encoding categorical features, and applying quantile-based binning.
def preprocess(
    behavior: pl.DataFrame, 
    history: pl.DataFrame, 
    article_publish: Dict[int, datetime], 
    article_index: Dict[int, int], 
    sorted_history: bool = False, 
    time_quartiles: Optional[List[int]] = None
) -> Tuple[pl.DataFrame, pl.DataFrame, List[int]]:
    # Step 1: Compute time-related features for history and behavior
    print('Computing history & behavior dates')
    history = compute_history_dates(history, article_publish, sorted_history)
    behavior = compute_behaviors_dates(behavior, article_publish)

    # Step 2: Compute user ID index mapping and time quantile limits
    print('Computing history user_id mapping and time percentiles')
    historyid_index = {user_id: i for i, user_id in enumerate(history['user_id'])}

    if time_quartiles is None:
        # Replace None with an empty list
        time_quartiles = quantile_limits(chain(*[x for x in history['article_delta_time'].to_list() if x is not None]))

    # Step 3: Process the behavior DataFrame
    print('Preprocessing behavior')
    behavior_ops = []

    if 'next_scroll_percentage' in behavior.columns:
        behavior_ops.append(
            pl.col('next_scroll_percentage').map_elements(
                lambda x: x / 100 if x is not None else 0.0, 
                return_dtype=pl.Float64, 
                skip_nulls=False
            )
        )
        behavior_ops.append(
            pl.struct(['article_ids_inview', 'article_ids_clicked']).map_elements(
                lambda x: [x['article_ids_inview'].index(v) for v in x['article_ids_clicked']],
                return_dtype=pl.List(pl.Int64)
            ).alias('article_ids_clicked')
        )

    categorical_mappings = [
        ('user_id', historyid_index, None),
        ('is_sso_user', None, binary_encoding),
        ('gender', None, build_dict_encoding([0, 1, 2, None])),
        ('postcode', None, build_dict_encoding([0, 1, 2, 3, 4, None])),
        ('age', None, build_dict_encoding([0, 10, 20, 30, 40, 50, 60, 70, 80, 90, 100, None])),
        ('is_subscriber', None, binary_encoding),
        ('device_type', None, build_dict_encoding([0, 1, 2, 3])),
    ]

    for column, mapping, encoding_function in categorical_mappings:
        if mapping:
            behavior_ops.append(pl.col(column).replace(mapping))
        elif encoding_function:
            behavior_ops.append(pl.col(column).map_elements(encoding_function, return_dtype=pl.Int8))

    behavior_ops.extend([
        pl.col('article_delta_time').map_elements(
            lambda x: [time_encoding(v, time_quartiles) for v in x], 
            return_dtype=pl.List(pl.Int64)
        ),
        pl.col('article_ids_inview').map_elements(
            lambda x: [article_index[v] for v in x], 
            return_dtype=pl.List(pl.Int64)
        ),
        pl.col('impression_time').dt.hour().alias('impression_hour'),
        (pl.col('impression_time').dt.weekday() - 1).alias('impression_weekday')
    ])

    behavior = behavior.with_columns(behavior_ops)

    # Step 4: Process the history DataFrame
    print('Preprocessing history')
    history_ops = [
        pl.col('scroll_percentage_fixed').map_elements(
            lambda l: [x / 100 if x is not None else 0.0 for x in l], 
            return_dtype=pl.List(pl.Float64)
        ),
        pl.col('article_delta_time').map_elements(
            lambda x: [time_encoding(v, time_quartiles) for v in x], 
            return_dtype=pl.List(pl.Int64)
        ),
        pl.col('article_id_fixed').map_elements(
            lambda x: [article_index[v] for v in x], 
            return_dtype=pl.List(pl.Int64)
        )
    ]

    history = history.with_columns(history_ops)

    return behavior, history, time_quartiles

In [91]:
ds = load_parquets('datasets/ebnerd_demo.zip')
behavior = ds['train/behaviors']
history = ds['train/history']
article = ds['articles']

del ds

In [92]:
article, categories_embs, ner_embs, topics_embs, articleid_indexes, article_publish_date = preprocess_article(article, 1000)



Feature article_type encoded. Number of values: 2
Feature subcategory encoded. Number of values: 2
Feature sentiment_label encoded. Number of values: 3
Feature category encoded. Number of values: 5
Feature ner_clusters encoded. Number of values: 3
Feature topics encoded. Number of values: 13


In [93]:
behavior, history, time_quartiles = preprocess(behavior, history, article_publish_date, articleid_indexes)

Computing history & behavior dates
Computing history user_id mapping and time percentiles
Preprocessing behavior
Preprocessing history


In [94]:
behavior.head(15)

impression_id,article_id,impression_time,read_time,scroll_percentage,device_type,article_ids_inview,article_ids_clicked,user_id,is_sso_user,gender,postcode,age,is_subscriber,session_id,next_read_time,next_scroll_percentage,article_delta_time,number_articles,impression_hour,impression_weekday
u32,i32,datetime[μs],f32,f32,i8,list[i64],list[i64],u32,i8,i8,i8,i8,i8,u32,f32,f64,list[i64],i64,i8,i8
48401,,2023-05-21 21:06:50,21.0,,2,"[9657, 9360, … 8514]",[10],27,0,,,,0,21,16.0,0.27,"[56, 14, … 55]",,21,6
152513,9778745,2023-05-24 07:31:26,30.0,100.0,1,"[10004, 10014, … 9918]",[4],741,0,,,,0,298,2.0,0.48,"[49, 3, … 64]",,7,2
155390,,2023-05-24 07:30:33,45.0,,1,"[9987, 9945, … 9994]",[1],8,0,,,,0,401,215.0,1.0,"[72, 76, … 76]",,7,2
214679,,2023-05-23 05:25:40,33.0,,2,"[9861, 9837, … 9864]",[2],574,0,,,,0,1357,40.0,0.47,"[58, 13, … 19]",,5,1
214681,,2023-05-23 05:31:54,21.0,,2,"[9706, 9864, … 9853]",[5],574,0,,,,0,1358,5.0,0.49,"[95, 24, … 85]",,5,1
…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…
371838,,2023-05-24 14:23:00,108.0,,2,"[10033, 10062, … 10072]",[11],79,0,,,,0,2015,1270.0,0.29,"[76, 60, … 23]",,14,2
374777,9779289,2023-05-24 14:15:08,12.0,100.0,2,"[10072, 10074, … 9744]",[2],1123,0,,,,0,2470,11.0,0.0,"[17, 25, … 96]",,14,2
386918,,2023-05-20 20:41:22,26.0,,2,"[9586, 9629, … 9631]",[0],1511,0,,,,0,2512,128.0,1.0,"[48, 10, … 28]",,20,5
780811,,2023-05-21 18:40:56,28.0,,2,"[9708, 9703, … 9667]",[6],312,0,,,,0,3639,26.0,1.0,"[15, 18, … 4]",,18,6


In [95]:
history.head(15)

user_id,impression_time_fixed,scroll_percentage_fixed,article_id_fixed,read_time_fixed,impression_weekday,impression_hour,article_delta_time,number_articles
u32,list[datetime[μs]],list[f64],list[i64],list[f32],list[i64],list[i64],list[i64],i64
13538,"[2023-04-27 10:17:43, 2023-04-27 10:18:01, … 2023-05-17 20:36:34]","[1.0, 0.35, … 1.0]","[6391, 6382, … 9222]","[17.0, 12.0, … 16.0]","[3, 3, … 2]","[10, 10, … 20]","[4, 31, … 37]",582
58608,"[2023-04-27 18:48:09, 2023-04-27 18:48:45, … 2023-05-17 19:46:40]","[0.37, 0.61, … 0.0]","[6470, 6444, … 9311]","[2.0, 24.0, … 0.0]","[3, 3, … 2]","[18, 18, … 19]","[7, 57, … 79]",151
95507,"[2023-04-27 15:20:28, 2023-04-27 15:20:47, … 2023-05-17 14:57:46]","[0.6, 1.0, … 0.0]","[6427, 6387, … 9234]","[18.0, 29.0, … 0.0]","[3, 3, … 2]","[15, 15, … 14]","[42, 37, … 70]",370
106588,"[2023-04-27 08:29:09, 2023-04-27 08:29:26, … 2023-05-16 05:50:52]","[0.24, 0.57, … 1.0]","[6357, 6354, … 7318]","[9.0, 15.0, … 33.0]","[3, 3, … 1]","[8, 8, … 5]","[39, 47, … 95]",149
617963,"[2023-04-27 14:42:25, 2023-04-27 14:43:10, … 2023-05-18 02:28:09]","[1.0, 1.0, … 0.9]","[6427, 6434, … 9349]","[45.0, 29.0, … 22.0]","[3, 3, … 3]","[14, 14, … 2]","[19, 1, … 78]",277
…,…,…,…,…,…,…,…,…
171559,"[2023-04-28 09:33:17, 2023-04-28 09:34:23, … 2023-05-17 22:04:20]","[0.93, 1.0, … 1.0]","[6563, 6482, … 9256]","[17.0, 72.0, … 48.0]","[4, 4, … 2]","[9, 9, … 22]","[4, 75, … 81]",127
189524,"[2023-04-27 07:12:05, 2023-04-27 07:12:54, … 2023-05-18 05:40:40]","[0.57, 0.33, … 1.0]","[6355, 6326, … 9355]","[13.0, 13.0, … 106.0]","[3, 3, … 3]","[7, 7, … 5]","[11, 72, … 0]",270
373598,"[2023-04-27 22:39:22, 2023-04-27 23:12:03, … 2023-05-18 02:37:56]","[0.14, 1.0, … 0.7]","[6256, 6420, … 9282]","[45.0, 79.0, … 7.0]","[3, 3, … 3]","[22, 23, … 2]","[38, 83, … 90]",196
383378,"[2023-04-28 10:59:24, 2023-04-28 11:02:37, … 2023-05-17 21:56:19]","[0.0, 0.99, … 1.0]","[6584, 6555, … 9349]","[192.0, 0.0, … 205.0]","[4, 4, … 2]","[10, 11, … 21]","[18, 2, … 12]",750


In [96]:
article.head()

article_id,title,subtitle,last_modified_time,premium,body,published_time,image_ids,article_type,url,ner_clusters,entity_groups,topics,category,subcategory,category_str,total_inviews,total_pageviews,total_read_time,sentiment_score,sentiment_label,category_link,ner_clusters_link,topics_link,pub_weekday,pub_hour
i32,str,str,datetime[μs],i64,str,datetime[μs],list[i64],i64,str,list[i64],list[str],list[i64],i64,list[i64],str,i32,i32,f32,f32,i64,i64,list[i64],list[i64],i8,i8
3037230,"""Ishockey-spiller: Jeg troede j…","""ISHOCKEY: Ishockey-spilleren S…",2023-06-29 06:20:57,0,"""Ambitionerne om at komme til U…",2003-08-28 08:55:00,,0,"""https://ekstrabladet.dk/sport/…",[],[],"[4, 3, 9]",2,[],"""sport""",,,,0.9752,0,22,[],"[25, 21, … 37]",3,8
3044020,"""Prins Harry tvunget til dna-te…","""Hoffet tvang Prins Harry til a…",2023-06-29 06:21:16,0,"""Den britiske tabloidavis The S…",2005-06-29 08:47:00,"[3097307, 3097197, 3104927]",0,"""https://ekstrabladet.dk/underh…",[],"[""PER"", ""PER""]","[4, 3, … 6]",3,[],"""underholdning""",,,,0.7084,0,23,"[11216, 13248]","[25, 21, … 47]",2,8
3057622,"""Rådden kørsel på blå plader""","""Kan ikke straffes: Udenlandske…",2023-06-29 06:21:24,0,"""Slingrende spritkørsel. Grove …",2005-10-10 07:20:00,[3047102],0,"""https://ekstrabladet.dk/nyhede…",[],[],[4],0,[0],"""nyheder""",,,,0.9236,0,12,[],"[25, 66, 4]",0,7
3073151,"""Mærsk-arvinger i livsfare""","""FANGET I FLODBØLGEN: Skibsrede…",2023-06-29 06:21:38,0,"""To oldebørn af skibsreder Mærs…",2005-01-04 06:59:00,"[3067474, 3067478, 3153705]",0,"""https://ekstrabladet.dk/nyhede…",[],[],"[1, 8, 5]",0,[0],"""nyheder""",,,,0.9945,0,12,[],"[10, 50, … 52]",1,6
3193383,"""Skød svigersøn gennem babydyne""","""44-årig kvinde tiltalt for dra…",2023-06-29 06:22:57,0,"""En 44-årig mormor blev i dag f…",2003-09-15 15:30:00,,0,"""https://ekstrabladet.dk/krimi/…",[],[],"[4, 6]",1,[],"""krimi""",,,,0.9966,0,9,[],"[25, 47]",0,15


In [97]:
behavior.write_parquet('preprocess/behaviors_train.parquet')
history.write_parquet('preprocess/history_train.parquet')
article.write_parquet('preprocess/article.parquet')
categories_embs.write_parquet('preprocess/categories_embeddings.parquet')
ner_embs.write_parquet('preprocess/ner_embeddings.parquet')
topics_embs.write_parquet('preprocess/topics_embeddings.parquet')

In [98]:
dataset = load_parquets('datasets/ebnerd_demo.zip')
behavior_validation = dataset['validation/behaviors']
history_validation = dataset['validation/history']
del dataset

behavior_validation, history_validation, _ = preprocess(behavior_validation, history_validation, article_publish_date, articleid_indexes, time_quartiles=time_quartiles)

Computing history & behavior dates
Computing history user_id mapping and time percentiles
Preprocessing behavior
Preprocessing history


In [99]:
behavior_validation.write_parquet('preprocess/behaviors_validation.parquet')
history_validation.write_parquet('preprocess/history_validation.parquet')

del behavior_validation
del history_validation

In [100]:
dataset = load_parquets('datasets/ebnerd_testset.zip')
behavior_test = dataset['ebnerd_testset/test/behaviors']
history_test = dataset['ebnerd_testset/test/history']
del dataset

behavior_test, history_test, _ = preprocess(behavior_test, history_test, article_publish_date, articleid_indexes, time_quartiles=time_quartiles)

Computing history & behavior dates
Computing history user_id mapping and time percentiles
Preprocessing behavior
Preprocessing history


In [101]:
behavior_test.write_parquet('preprocess/behaviors_test.parquet')
history_test.write_parquet('preprocess/history_test.parquet')

del behavior_test
del history_test

In [104]:
images = load_parquets('datasets/Ekstra_Bladet_image_embeddings.zip')['Ekstra_Bladet_image_embeddings/image_embeddings']
image_embeddings = torch.from_numpy(np.asarray(list(images['image_embedding'])))

In [105]:
images.head()

article_id,image_embedding
i32,list[f32]
9734738,"[-0.023273, -0.039152, … -0.018453]"
8647636,"[-0.04257, -0.03809, … -0.017567]"
9715678,"[0.012374, 0.02872, … 0.006563]"
4001699,"[-0.026271, 0.043816, … -0.024919]"
4127160,"[-0.029413, 0.012622, … 0.007242]"


In [108]:
image_embeddings

tensor([[-0.0233, -0.0392, -0.0225,  ...,  0.0141,  0.0508, -0.0185],
        [-0.0426, -0.0381, -0.0065,  ..., -0.0060,  0.0319, -0.0176],
        [ 0.0124,  0.0287,  0.0160,  ...,  0.0137, -0.0131,  0.0066],
        ...,
        [-0.0118,  0.0167, -0.0028,  ...,  0.0503, -0.0552, -0.0017],
        [-0.0341, -0.0155, -0.0303,  ...,  0.0745, -0.0248, -0.0297],
        [-0.0296,  0.0014, -0.0294,  ...,  0.0188, -0.0003, -0.0026]])

In [109]:
from torch.utils.data import DataLoader, TensorDataset

data_loader = DataLoader(TensorDataset(image_embeddings), batch_size=512, shuffle=True)

In [110]:
from torch import nn
import torch.nn.functional as F

class Normalize(nn.Module):
    
    def __init__(self):
        super().__init__()
        pass

    def forward(self, x):
        return F.normalize(x, dim=-1)

class EncodeDecoder(nn.Module):

    def __init__(self, in_size=1024, hidden=[512, 128]):
        super().__init__()
        layers = [in_size] + hidden
        self.encoder = nn.Sequential(*([nn.Linear(layers[i], layers[i+1], bias=False) for i in range(len(layers) - 1)] + [Normalize()]))
        layers = layers[::-1]
        self.decoder = nn.Sequential(*[nn.Linear(layers[i], layers[i+1], bias=False) for i in range(len(layers) - 1)])
        pass

    def forward(self, x):
        return self.decoder(self.encoder(x))

In [126]:
def setup_device():
    """Setup device (CPU or GPU) for training."""
    return 'cuda' if torch.cuda.is_available() else 'cpu'


def load_or_initialize_model(model, model_path, device):
    """Load the model from the saved checkpoint or initialize a new one."""
    if os.path.exists(model_path):
        print(f"Loading model from {model_path}")
        model.load_state_dict(torch.load(model_path))
    else:
        print(f"Training a new model. No checkpoint found at {model_path}")
    return model.to(device)


def train_model(model, optimizer, data_loader, device, loss_fn, epochs=100, model_path='preprocess/image_encoder.pth'):
    """Train the model for a specified number of epochs and save its state."""
    for epoch in range(1, epochs):
        epoch_loss = train_one_epoch(model, optimizer, data_loader, device, loss_fn, epoch)
        print(f'Epoch {epoch}: Loss = {epoch_loss}')

    torch.save(model.state_dict(), model_path)
    print(f"Model saved to {model_path}")


def train_one_epoch(model, optimizer, data_loader, device, loss_fn, epoch):
    """Train the model for a single epoch."""
    # model.train()
    cumulative_loss = 0
    num_steps = 0

    # with tqdm(data_loader, leave=False) as progress_bar:
    for batch in data_loader:
        optimizer.zero_grad()
        x = batch[0].to(device)
        predictions = model(x)
        loss = loss_fn(predictions, x)
        loss.backward()
        optimizer.step()
        
        cumulative_loss += loss.item()
        num_steps+=1
            # progress_bar.set_postfix(epoch=epoch, loss=cumulative_loss / num_steps)
    
    return cumulative_loss / num_steps

In [127]:
from torch import optim

device = setup_device()
    
# Initialize model, optimizer, and loss function
model = EncodeDecoder().to(device)
optimizer = optim.Adam(model.parameters())
loss_fn = nn.MSELoss()

# Model path
model_path = 'preprocess/image_encoder.pth'

# Load model if checkpoint exists, else start from scratch
model = load_or_initialize_model(model, model_path, device)

# Train model
train_model(model, optimizer, data_loader, device, loss_fn, epochs=100, model_path=model_path)

model.eval()

Training a new model. No checkpoint found at preprocess/image_encoder.pth
Epoch 1: Loss = 4.079316700447078e-05
Epoch 2: Loss = 8.244247631769687e-07
Epoch 3: Loss = 3.917144327156686e-07
Epoch 4: Loss = 3.4455740194007645e-07
Epoch 5: Loss = 3.114017838990675e-07
Epoch 6: Loss = 2.651395495594266e-07
Epoch 7: Loss = 2.367855713250199e-07
Epoch 8: Loss = 2.2252195863493578e-07
Epoch 9: Loss = 2.3785028440557309e-07
Epoch 10: Loss = 2.0179111753809572e-07
Epoch 11: Loss = 2.084538561266934e-07
Epoch 12: Loss = 2.0565507495304238e-07
Epoch 13: Loss = 1.9936636775736728e-07
Epoch 14: Loss = 1.9721312109292435e-07
Epoch 15: Loss = 1.9209090455366657e-07
Epoch 16: Loss = 1.9032779901813476e-07
Epoch 17: Loss = 1.8229686302231528e-07
Epoch 18: Loss = 1.7678653114330135e-07
Epoch 19: Loss = 1.781364825200765e-07
Epoch 20: Loss = 1.8526445638849396e-07
Epoch 21: Loss = 1.9245801805302326e-07
Epoch 22: Loss = 1.693525711661281e-07
Epoch 23: Loss = 1.9214004968386522e-07
Epoch 24: Loss = 1.62408

EncodeDecoder(
  (encoder): Sequential(
    (0): Linear(in_features=1024, out_features=512, bias=False)
    (1): Linear(in_features=512, out_features=128, bias=False)
    (2): Normalize()
  )
  (decoder): Sequential(
    (0): Linear(in_features=128, out_features=512, bias=False)
    (1): Linear(in_features=512, out_features=1024, bias=False)
  )
)

In [129]:
data_loader = DataLoader(TensorDataset(image_embeddings), batch_size=512, shuffle=False)

embeddings = []
with torch.no_grad():
    for x in data_loader:
        x = x[0].to(device)
        embedding = model.encoder(x).cpu().numpy()
        embeddings.append(embedding)

embeddings = np.concatenate(embeddings, axis=0)
embeddings = pl.DataFrame(data=embeddings, schema=['embeddings']).with_columns(images['article_id']).select('article_id', 'embeddings')
embeddings

article_id,embeddings
i32,"array[f32, 128]"
9734738,"[0.263068, 0.135309, … -0.211787]"
8647636,"[0.274699, 0.283317, … -0.179495]"
9715678,"[0.044149, -0.111148, … -0.004841]"
4001699,"[0.301314, 0.272426, … -0.062496]"
4127160,"[0.326045, 0.207151, … -0.05491]"
…,…
5526970,"[0.34692, 0.222736, … 0.097438]"
4006630,"[0.325028, 0.244405, … 0.195393]"
7742851,"[0.189014, 0.07577, … -0.080563]"
4789449,"[0.282656, 0.001676, … 0.11607]"


In [130]:
embeddings.write_parquet('preprocess/image_embeddings.parquet')