In [1]:
import pandas as pd 
import torch

df = pd.read_parquet('data/merged_data.parquet')
df.head()

Unnamed: 0,customer_id,article_id,price,purchase_flag,age,postal_code,product_code,prod_name,product_type_no,product_type_name,...,department_name,index_code,index_name,index_group_no,index_group_name,section_no,section_name,garment_group_no,garment_group_name,detail_desc
0,000058a12d5b43e67d225668fa1f8d618c13dc232df0ca...,663713001,0.050831,1,24.0,64f17e6a330a85798e4998f62d0930d14db8db1c054af6...,663713,Atlanta Push Body Harlow,283,Underwear body,...,Expressive Lingerie,B,Lingeries/Tights,1,Ladieswear,61,Womens Lingerie,1017,"Under-, Nightwear","Lace push-up body with underwired, moulded, pa..."
1,000058a12d5b43e67d225668fa1f8d618c13dc232df0ca...,541518023,0.030492,1,24.0,64f17e6a330a85798e4998f62d0930d14db8db1c054af6...,541518,Rae Push (Melbourne) 2p,306,Bra,...,Casual Lingerie,B,Lingeries/Tights,1,Ladieswear,61,Womens Lingerie,1017,"Under-, Nightwear","Lace push-up bras with underwired, moulded, pa..."
2,00007d2de826758b65a93dd24ce629ed66842531df6699...,505221004,0.015237,1,32.0,8d6f45050876d059c830a0fe63f1a4c022de279bb68ce3...,505221,Inca Jumper,252,Sweater,...,Tops Knitwear DS,D,Divided,2,Divided,58,Divided Selected,1003,Knitwear,Jumper in rib-knit cotton with hard-worn detai...
3,00007d2de826758b65a93dd24ce629ed66842531df6699...,685687003,0.016932,1,32.0,8d6f45050876d059c830a0fe63f1a4c022de279bb68ce3...,685687,W YODA KNIT OL OFFER,252,Sweater,...,Campaigns,A,Ladieswear,1,Ladieswear,15,Womens Everyday Collection,1023,Special Offers,V-neck knitted jumper with long sleeves and ri...
4,00007d2de826758b65a93dd24ce629ed66842531df6699...,685687004,0.016932,1,32.0,8d6f45050876d059c830a0fe63f1a4c022de279bb68ce3...,685687,W YODA KNIT OL OFFER,252,Sweater,...,Campaigns,A,Ladieswear,1,Ladieswear,15,Womens Everyday Collection,1023,Special Offers,V-neck knitted jumper with long sleeves and ri...


In [2]:
fact_customer = pd.factorize(df['customer_id'])
fact_article = pd.factorize(df['article_id'])

In [12]:
import torch.nn as nn
import torch.optim as optim

class ALSModel(nn.Module):
    def __init__(self, num_users, num_items, factors=50, reg=0.1):
        super(ALSModel, self).__init__()
        self.user_factors = nn.Parameter(torch.randn(num_users, factors) * 0.01)
        self.item_factors = nn.Parameter(torch.randn(num_items, factors) * 0.01)
        self.reg = reg

    def forward(self, user_indices, item_indices):
        user_embeds = self.user_factors[user_indices]
        item_embeds = self.item_factors[item_indices]
        return (user_embeds * item_embeds).sum(dim=1)
    
    def loss(self, predictions, targets):
        mse = torch.mean((predictions - targets) ** 2)
        reg_loss = self.reg * (torch.sum(self.user_factors**2) + torch.sum(self.item_factors**2))
        return mse + reg_loss
    
num_users = len(fact_customer[1])
num_items = len(fact_article[1])
factors=50
reg=0.175

device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
model = ALSModel(num_users, num_items, factors, reg).to(device)

optimizer = optim.Adam(model.parameters(), lr=0.001)

user_indices = torch.randint(0, num_users, (100000,), device=device)
item_indices = torch.randint(0, num_items, (100000,), device=device)
ratings = torch.rand(100000, device=device)

for epoch in range(20):
    model.train()
    optimizer.zero_grad()

    predictions = model(user_indices, item_indices)
    loss = model.loss(predictions, ratings)

    loss.backward()
    optimizer.step()

    print(f"Epoch {epoch+1}: Loss = {loss.item()}")

Epoch 1: Loss = 1283.551513671875
Epoch 2: Loss = 1091.602294921875
Epoch 3: Loss = 923.7034912109375
Epoch 4: Loss = 778.5209350585938
Epoch 5: Loss = 653.806640625
Epoch 6: Loss = 547.248779296875
Epoch 7: Loss = 456.8792419433594
Epoch 8: Loss = 380.842529296875
Epoch 9: Loss = 317.286865234375
Epoch 10: Loss = 264.41876220703125
Epoch 11: Loss = 220.65341186523438
Epoch 12: Loss = 184.63694763183594
Epoch 13: Loss = 155.170166015625
Epoch 14: Loss = 131.15481567382812
Epoch 15: Loss = 111.60862731933594
Epoch 16: Loss = 95.70573425292969
Epoch 17: Loss = 82.77359008789062
Epoch 18: Loss = 72.25118255615234
Epoch 19: Loss = 63.654972076416016
Epoch 20: Loss = 56.57658004760742
