<a href="https://colab.research.google.com/github/ihatethinkingname/dl_practice/blob/main/sasrec_practice.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import pandas as pd

import copy
import math
import os

In [None]:
# Transformer practice part

class EnD(nn.Module):
  """Encoder-decoder wrapper tying together embeddings, encoder, decoder and generator.

  Args:
    encoder: module called as ``encoder(embedded_src, src_mask)``.
    decoder: module called as ``decoder(memory, src_mask, embedded_tgt, tgt_mask)``.
    src_embed: module mapping source token ids to embeddings.
    tgt_embed: module mapping target token ids to embeddings.
    generator: output head; stored here but not applied by ``forward``.
  """
  def __init__(self,encoder,decoder,src_embed,tgt_embed,generator):
    super().__init__()
    self.encoder=encoder
    self.decoder=decoder
    self.src_embed=src_embed
    self.tgt_embed=tgt_embed
    self.generator=generator

  def forward(self,src,tgt,src_mask,tgt_mask):
    """Encode ``src`` and decode ``tgt`` against the resulting memory.

    Goes through ``encode``/``decode`` so that both inputs are embedded first;
    calling the encoder/decoder directly would hand them raw token ids.
    """
    memory=self.encode(src,src_mask)
    return self.decode(memory,src_mask,tgt,tgt_mask)

  def encode(self,src,src_mask):
    """Embed ``src`` and run it through the encoder."""
    return self.encoder(self.src_embed(src),src_mask)

  def decode(self,memory,src_mask,tgt,tgt_mask):
    """Embed ``tgt`` and run the decoder over it with ``memory`` as context."""
    return self.decoder(memory,src_mask,self.tgt_embed(tgt),tgt_mask)

class Embeddings(nn.Module):
  """Token-embedding lookup whose output is scaled by ``sqrt(d_model)``.

  Index 0 is the padding index of the underlying ``nn.Embedding``.
  """
  def __init__(self,vocab_size,d_model):
    super().__init__()
    self.lut=nn.Embedding(num_embeddings=vocab_size,embedding_dim=d_model,padding_idx=0)
    self.d_model=d_model

  def forward(self,x):
    """Return the embeddings of the ids in ``x`` multiplied by ``sqrt(d_model)``."""
    scale=math.sqrt(self.d_model)
    vectors=self.lut(x)
    return vectors*scale

class PositionalEncoding(nn.Module):
  """Learned positional embedding added to the input, followed by dropout.

  Args:
    d_model: embedding width.
    max_len: number of positions in the embedding table.
    dropout_rate: dropout probability applied to the summed output.
  """
  def __init__(self,d_model,max_len=5000,dropout_rate=0.1):
    super().__init__()
    self.pos_embedding=nn.Embedding(max_len,d_model)
    self.dropout=nn.Dropout(dropout_rate)

  def forward(self,x,zero_padding=True):
    """Add position embeddings to ``x``.

    Args:
      x: embedded sequence; assumed to be (batch, seq_len, d_model) -- TODO confirm
        against callers. Only the first two dimensions are read here.
      zero_padding: when True, positions whose feature vector is all zeros are
        treated as padding and receive position index 0 instead of their own index.

    Returns:
      ``dropout(x + pos_embedding(position_index))``.
    """
    batch_size,seq_len=x.shape[:2]
    pos_idx=torch.arange(seq_len,device=x.device).unsqueeze(0).expand(batch_size,seq_len)
    if zero_padding:
      pad_mask=x.abs().sum(dim=-1)!=0
      # Out-of-place multiply: pos_idx is an expanded view, so an in-place
      # update would write to shared memory and PyTorch rejects it.
      pos_idx=pos_idx*pad_mask.long()
    x = x + self.pos_embedding(pos_idx)
    return self.dropout(x)

class Generator(nn.Module):
  """Output head: linear projection to vocabulary size, then log-softmax."""
  def __init__(self,d_model,vocab_size):
    super().__init__()
    self.proj=nn.Linear(in_features=d_model,out_features=vocab_size)

  def forward(self,x):
    """Return log-probabilities over the last dimension of ``proj(x)``."""
    logits=self.proj(x)
    return logits.log_softmax(dim=-1)


class EnDcoder(nn.Module):
  """Stack of ``N`` copies of ``layer`` followed by a final LayerNorm.

  The same class serves as encoder (``tgt``/``tgt_mask`` omitted) and as
  decoder (both supplied).

  Args:
    layer: prototype layer; must expose ``d_model`` and be callable as
      ``layer(x, src_mask, tgt, tgt_mask)``.
    N: number of copies to stack.
  """
  def __init__(self,layer,N):
    super().__init__()
    self.layers=clone(layer,N)
    self.norm=nn.LayerNorm(layer.d_model)

  def forward(self,x,src_mask,tgt=None,tgt_mask=None):
    """Run the stack and normalise the result.

    Args:
      x: encoder input in encoder mode; encoder memory in decoder mode.
      src_mask: mask passed to every layer for attention over the source.
      tgt: decoder input, or None for encoder mode.
      tgt_mask: mask for decoder self-attention, or None for encoder mode.

    Returns:
      The layer-normalised output of the last layer.
    """
    if tgt is None:
      # Encoder mode. tgt_mask is still forwarded so each layer can validate
      # the (tgt, tgt_mask) pair itself.
      for layer in self.layers:
        x=layer(x,src_mask,tgt,tgt_mask)
      return self.norm(x)
    # Decoder mode: the memory `x` stays fixed for every layer while the
    # target stream is what gets refined layer by layer. Re-assigning `x`
    # here would replace the memory with decoder output and restart each
    # layer from the original `tgt`.
    for layer in self.layers:
      tgt=layer(x,src_mask,tgt,tgt_mask)
    return self.norm(tgt)

def clone(layer,N):
  """Return an ``nn.ModuleList`` holding ``N`` independent deep copies of ``layer``."""
  copies=nn.ModuleList()
  for _ in range(N):
    copies.append(copy.deepcopy(layer))
  return copies

class EnDcoderLayer(nn.Module):
  """One encoder or decoder layer built from residual sublayer connections.

  Without ``src_attn`` the layer holds two connections (self-attention, FFN);
  with ``src_attn`` a third one is appended for attention over the memory.
  """
  def __init__(self,self_attn,ffn,d_model,dropout_rate,src_attn=None):
    super().__init__()
    self.self_attn=self_attn
    self.ffn=ffn
    self.connector=clone(SublayerConnection(d_model,dropout_rate),2)
    self.d_model=d_model
    has_src_attn=src_attn is not None
    if has_src_attn:
      self.src_attn=src_attn
      self.connector.append(SublayerConnection(d_model,dropout_rate))

  def forward(self,x,src_mask,tgt=None,tgt_mask=None):
    """Apply the layer.

    In encoder mode (``tgt`` and ``tgt_mask`` both None) ``x`` is the stream
    being transformed. In decoder mode ``x`` is the memory and ``tgt`` is the
    stream being transformed.
    """
    if not isEncoder(tgt,tgt_mask):
      memory=x

      def masked_self_attn(h):
        return self.self_attn(h,h,h,tgt_mask)

      def memory_attn(h):
        return self.src_attn(h,memory,memory,src_mask)

      hidden=self.connector[0](tgt,masked_self_attn)
      hidden=self.connector[1](hidden,memory_attn)
    else:
      def source_self_attn(h):
        return self.self_attn(h,h,h,src_mask)

      hidden=self.connector[0](x,source_self_attn)
    # The feed-forward sublayer always uses the last connection.
    return self.connector[-1](hidden,self.ffn)

def isEncoder(tgt,tgt_mask):
  """Tell encoder mode from decoder mode by the presence of ``tgt``/``tgt_mask``.

  Returns True when both are None, False when both are given.

  Raises:
    ValueError: if exactly one of the two is None.
  """
  tgt_missing=tgt is None
  mask_missing=tgt_mask is None
  if tgt_missing!=mask_missing:
    raise ValueError("Encoder instance input error: "
                     "tgt and tgt_mask should be both None or both not None.")
  return tgt_missing

class SublayerConnection(nn.Module):
  """Residual connection around a sublayer, with LayerNorm applied to its input."""
  def __init__(self,d_model,dropout_rate):
    super().__init__()
    self.norm=nn.LayerNorm(d_model)
    self.dropout=nn.Dropout(dropout_rate)

  def forward(self,x,sublayer):
    """Return ``x + dropout(sublayer(norm(x)))``."""
    normed=self.norm(x)
    branch=self.dropout(sublayer(normed))
    return x+branch

def attention(q,k,v,mask=None,dropout=None):
  """Scaled dot-product attention.

  Args:
    q, k, v: query, key and value tensors; the last dimension of ``q`` sets the scale.
    mask: optional tensor broadcastable to the score matrix; entries equal to 0
      are filled with -1e9 before the softmax.
    dropout: optional callable applied to the attention weights.

  Returns:
    Tuple ``(weights @ v, weights)``.
  """
  scale=math.sqrt(q.shape[-1])
  scores=q.matmul(k.transpose(-2,-1))/scale
  if mask is not None:
    blocked=mask==0
    scores=scores.masked_fill(blocked,-1e9)
  weights=scores.softmax(dim=-1)
  if dropout is not None:
    weights=dropout(weights)
  return weights.matmul(v),weights

class MultiHeadedAttention(nn.Module):
  """Multi-head attention with four ``d_model x d_model`` linear maps.

  The first three maps project query, key and value; the last one projects
  the concatenated heads. ``d_model`` must be divisible by ``n_head``.
  """
  def __init__(self,n_head,d_model,dropout_rate=0.1):
    super().__init__()
    assert d_model%n_head==0
    self.d_head=d_model//n_head
    self.n_head=n_head
    self.linears=clone(nn.Linear(d_model,d_model),4)
    self.p_attn=None    # attention weights of the most recent forward pass
    self.dropout=nn.Dropout(p=dropout_rate)

  def forward(self,quary,key,value,mask=None):
    """Project the inputs, attend per head, merge the heads and project again.

    ``mask``, when given, gets an extra dimension at index 1 so it broadcasts
    across heads.
    """
    if mask is not None:
      mask=mask.unsqueeze(1)
    batch_size=quary.shape[0]

    def split_heads(proj,tensor):
      # (batch, seq, d_model) -> (batch, n_head, seq, d_head)
      return proj(tensor).view(batch_size,-1,self.n_head,self.d_head).transpose(1,2)

    q=split_heads(self.linears[0],quary)
    k=split_heads(self.linears[1],key)
    v=split_heads(self.linears[2],value)
    heads,self.p_attn=attention(q,k,v,mask=mask,dropout=self.dropout)
    merged=heads.transpose(1,2).contiguous().view(batch_size,-1,self.n_head*self.d_head)
    return self.linears[-1](merged)

class FFN(nn.Module):
  """Position-wise feed-forward network: Linear -> ReLU -> Dropout -> Linear.

  Args:
    d_model: input and output width.
    d_ff: hidden width.
    dropout_rate: dropout probability applied after the ReLU.
  """
  def __init__(self,d_model,d_ff,dropout_rate=0.1):
    super().__init__()
    self.L1=nn.Linear(d_model,d_ff)
    # Honour the caller's rate; this was previously pinned to 0.1.
    self.dropout=nn.Dropout(dropout_rate)
    self.L2=nn.Linear(d_ff,d_model)

  def forward(self,x):
    """Return ``L2(dropout(relu(L1(x))))``."""
    return self.L2(self.dropout(F.relu(self.L1(x))))

def causal_mask(seq_len,device):
  """Return a ``(seq_len, seq_len)`` boolean lower-triangular mask on ``device``.

  Entry ``[row, col]`` is True when ``row >= col``.
  """
  positions=torch.arange(seq_len,device=device)
  rows=positions.unsqueeze(1)
  cols=positions.unsqueeze(0)
  return rows>=cols

def make_model(src_vocab,tgt_vocab,d_model=512,n_sublayer=6,d_ff=2048,n_head=8,dropout_rate=0.1):
  """Assemble an encoder-decoder Transformer (``EnD``) and initialise its weights.

  Args:
    src_vocab: source vocabulary size.
    tgt_vocab: target vocabulary size.
    d_model: model width.
    n_sublayer: number of stacked layers in both the encoder and the decoder.
    d_ff: hidden width of the feed-forward sublayers.
    n_head: number of attention heads.
    dropout_rate: dropout probability passed to attention, feed-forward,
      sublayer connections and positional encoding.

  Returns:
    The constructed ``EnD`` model.
  """
  c=copy.deepcopy

  src_embed=Embeddings(src_vocab,d_model)
  tgt_embed=Embeddings(tgt_vocab,d_model)
  generator=Generator(d_model,tgt_vocab)
  # Pass the rate through so positional-encoding dropout follows the argument
  # rather than the class default.
  pos_embed=PositionalEncoding(d_model,dropout_rate=dropout_rate)

  attn=MultiHeadedAttention(n_head,d_model,dropout_rate)
  ffn=FFN(d_model,d_ff,dropout_rate)
  encoder=EnDcoder(EnDcoderLayer(c(attn),c(ffn),d_model,dropout_rate),n_sublayer)
  decoder=EnDcoder(EnDcoderLayer(c(attn),c(ffn),d_model,dropout_rate,c(attn)),n_sublayer)

  model=EnD(encoder,decoder,nn.Sequential(c(src_embed),c(pos_embed)),
            nn.Sequential(c(tgt_embed),c(pos_embed)),generator)

  # Xavier-initialise every weight matrix (parameters with more than one dim).
  for p in model.parameters():
    if p.dim()>1:
      nn.init.xavier_uniform_(p)

  # Xavier init also overwrites the zero rows reserved for padding_idx.
  # PositionalEncoding detects padding by all-zero embeddings, so restore them.
  with torch.no_grad():
    for module in model.modules():
      if isinstance(module,nn.Embedding) and module.padding_idx is not None:
        module.weight[module.padding_idx].zero_()

  return model


In [None]:
# SASRec practice part

class SASRec(nn.Module):
  """Self-attentive sequential recommender built from the Transformer blocks in this file.

  Item id 0 is treated as padding throughout (it is the embedding padding
  index and is masked out of attention).

  Args:
    n_user: number of users; stored but not otherwise used by this class.
    n_item: number of items.
    args: object providing ``device``, ``d_model``, ``max_len``,
      ``dropout_rate``, ``n_head``, ``d_ff`` and ``n_sublayer``.
  """
  def __init__(self,n_user,n_item,args):
    # nn.Module must be initialised before any submodule is assigned.
    super().__init__()
    self.n_user=n_user
    self.n_item=n_item
    self.device=args.device
    d_model=args.d_model

    # +1 row because id 0 is reserved for padding.
    # NOTE(review): assumes real item ids run 1..n_item -- confirm against the data pipeline.
    self.item_emb=Embeddings(n_item+1,d_model)
    self.pos_encode=PositionalEncoding(d_model,args.max_len,args.dropout_rate)

    # Prototypes only: EnDcoder deep-copies the layer, so keeping these as
    # attributes would register parameters that never take part in forward.
    self_attn=MultiHeadedAttention(args.n_head,d_model,args.dropout_rate)
    ffn=FFN(d_model,args.d_ff,args.dropout_rate)
    self.encoder=EnDcoder(EnDcoderLayer(self_attn,ffn,d_model,args.dropout_rate),args.n_sublayer)

  def log2feat(self,log_seq):
    """Encode an interaction log into per-position feature vectors.

    Args:
      log_seq: item ids as a tensor or anything ``torch.as_tensor`` accepts;
        assumed to be (batch, seq_len) with 0 as padding -- TODO confirm.

    Returns:
      The encoder output for every position.
    """
    if not torch.is_tensor(log_seq):
      log_seq = torch.as_tensor(log_seq)
    log_seq = log_seq.to(self.device, dtype=torch.long)
    # One leading broadcast dim only: MultiHeadedAttention adds the head dim itself.
    pad_mask=(log_seq!=0).unsqueeze(1)

    seq_emb=self.pos_encode(self.item_emb(log_seq))

    seq_len=seq_emb.shape[1]
    # Named differently from the module-level causal_mask() so the function is not shadowed.
    future_mask=causal_mask(seq_len,self.device).unsqueeze(0)
    combined_mask=pad_mask & future_mask
    feat_seq=self.encoder(seq_emb,combined_mask)

    return feat_seq

  def forward(self,log_seq,pos_seq,neg_seq):
    """Score positive and negative items at every position of the log.

    Args:
      log_seq: interaction log, see ``log2feat``.
      pos_seq: positive item ids, aligned with ``log_seq``; cast to long.
      neg_seq: negative item ids, aligned with ``log_seq``; cast to long.

    Returns:
      Tuple ``(pos_logits, neg_logits)``: dot products of the sequence
      features with the positive / negative item embeddings.
    """
    feat_seq=self.log2feat(log_seq)

    if not torch.is_tensor(pos_seq):
      pos_seq = torch.as_tensor(pos_seq)
    if not torch.is_tensor(neg_seq):
      neg_seq = torch.as_tensor(neg_seq)
    pos_seq_embed=self.item_emb(pos_seq.to(dtype=torch.long, device=self.device))
    neg_seq_embed=self.item_emb(neg_seq.to(dtype=torch.long, device=self.device))

    pos_logits=(feat_seq*pos_seq_embed).sum(dim=-1)
    neg_logits=(feat_seq*neg_seq_embed).sum(dim=-1)

    return pos_logits,neg_logits

  def predict(self,log_seq,item_indices):
    """Score candidate items against the feature at the last position of the log.

    Args:
      log_seq: interaction log, see ``log2feat``.
      item_indices: candidate item ids; cast to long.

    Returns:
      One logit per candidate item.
    """
    feat_seq=self.log2feat(log_seq)[:,-1,:]

    if not torch.is_tensor(item_indices):
      item_indices = torch.as_tensor(item_indices)
    item_indices=item_indices.to(dtype=torch.long,device=self.device)

    item_embed=self.item_emb(item_indices)

    logits=item_embed.matmul(feat_seq.unsqueeze(-1)).squeeze(-1)

    return logits

In [None]:
# Scratch value; appears to be left over from interactive experimentation.
a = [1, 2]


In [None]:
# SASRec on MovieLens-1M (ml-1m): fetch the dataset and load it into DataFrames.
# This cell relies on IPython `!` shell escapes, so it only runs inside a notebook.

# 1. Download the archive unless it is already present
if not os.path.exists("ml-1m.zip"):
    !wget http://files.grouplens.org/datasets/movielens/ml-1m.zip
else:
    print("✅ 已存在 ml-1m.zip，跳过下载")

# 2. Unzip unless the extracted folder is already present
if not os.path.exists("ml-1m"):
    !unzip -q ml-1m.zip
    print("✅ 解压完成")
else:
    print("✅ 已存在 ml-1m 文件夹，跳过解压")

# 3. Read ratings.dat
#    The "::" separator is more than one character, hence engine="python".
ratings = pd.read_csv(
    "ml-1m/ratings.dat",
    sep="::",
    engine="python",
    names=["UserID", "MovieID", "Rating", "Timestamp"],
    encoding="latin-1"
)

# 4. Read movies.dat
movies = pd.read_csv(
    "ml-1m/movies.dat",
    sep="::",
    engine="python",
    names=["MovieID", "Title", "Genres"],
    encoding="latin-1"
)

# 5. Read users.dat
users = pd.read_csv(
    "ml-1m/users.dat",
    sep="::",
    engine="python",
    names=["UserID", "Gender", "Age", "Occupation", "Zip-code"],
    encoding="latin-1"
)

# 6. Quick look at the first rows of each table
print("📌 Ratings 样例:")
print(ratings.head(), "\n")

print("📌 Movies 样例:")
print(movies.head(), "\n")

print("📌 Users 样例:")
print(users.head(), "\n")

# 7. Summary statistics (user/movie counts are taken from the ratings table)
print(f"用户数量: {ratings['UserID'].nunique()}")
print(f"电影数量: {ratings['MovieID'].nunique()}")
print(f"评分数量: {len(ratings)}")
print(f"评分范围: {ratings['Rating'].min()} ~ {ratings['Rating'].max()}")

# 8. Merge everything into one wide table (convenient for analysis)
df = ratings.merge(movies, on="MovieID").merge(users, on="UserID")
print("\n📌 合并后的样例:")
print(df.head())


✅ 已存在 ml-1m.zip，跳过下载
✅ 已存在 ml-1m 文件夹，跳过解压
📌 Ratings 样例:
   UserID  MovieID  Rating  Timestamp
0       1     1193       5  978300760
1       1      661       3  978302109
2       1      914       3  978301968
3       1     3408       4  978300275
4       1     2355       5  978824291 

📌 Movies 样例:
   MovieID                               Title                        Genres
0        1                    Toy Story (1995)   Animation|Children's|Comedy
1        2                      Jumanji (1995)  Adventure|Children's|Fantasy
2        3             Grumpier Old Men (1995)                Comedy|Romance
3        4            Waiting to Exhale (1995)                  Comedy|Drama
4        5  Father of the Bride Part II (1995)                        Comedy 

📌 Users 样例:
   UserID Gender  Age  Occupation Zip-code
0       1      F    1          10    48067
1       2      M   56          16    70072
2       3      M   25          15    55117
3       4      M   45           7    02460
4       5 