In [1]:
import thunder
import torch
from torch.profiler import ProfilerActivity, profile, record_function

from NueMF import (
    NeuMF,
    MovieLensDataModule,
    compile_model
)

In [16]:
# Sizes handed to the NeuMF constructor (presumably the MovieLens user/item
# vocabulary sizes -- TODO confirm against the dataset in use).
num_users = 6040
num_items = 3706

# Checkpoint files: one each for the GMF and MLP parts, plus the combined NeuMF model.
gmf_checkpoint = "/teamspace/studios/this_studio/RecSys/NueMF/checkpoints/gmf-epoch=09.ckpt.ckpt"
mlp_checkpoint = "/teamspace/studios/this_studio/RecSys/NueMF/checkpoints/mlp-epoch=02.ckpt.ckpt"
nuemf_checkpoint = "/teamspace/studios/this_studio/RecSys/NueMF/checkpoints/neumf-epoch=02-NeuMF_val_loss=0.00.ckpt"

# Everything in this block runs with CUDA as torch's default device.
with torch.device("cuda"):
    # Restore the combined model from its checkpoint.
    neumf_model = NeuMF.load_from_checkpoint(
        nuemf_checkpoint,
        gmf_checkpoint=gmf_checkpoint,
        mlp_checkpoint=mlp_checkpoint,
        num_users=num_users,
        num_items=num_items,
    )
    # Inference only: turn off gradient tracking and switch to eval mode.
    neumf_model.requires_grad_(False)
    neumf_model.eval()

    # Build the data module and take its validation loader for benchmarking.
    dm = MovieLensDataModule(batch_size=4096)
    dm.setup("bench")
    val_loader = dm.val_dataloader()

2024-12-01 14:39:12,570 - NueMF - INFO - Generating test samples...


In [17]:
# Take the first validation batch and place the three fields used below on the GPU.
batch = next(iter(val_loader))
user_ids, item_ids, ratings = (
    batch[field].to("cuda") for field in ("user_id", "item_id", "rating")
)

In [18]:
# Wrap the eager model with thunder's JIT while CUDA is the default device.
with torch.device("cuda"):
    jit_neumf_model = thunder.jit(neumf_model)

# Call the wrapped model once on the sample batch; the returned tensor is
# shown as the cell output.
jit_neumf_model(user_ids, item_ids)

tensor([[1.],
        [1.],
        [1.],
        ...,
        [1.],
        [1.],
        [1.]], device='cuda:0')

In [19]:
# Wrap the same eager model with torch.compile while CUDA is the default device.
with torch.device("cuda"):
    torch_compiled = torch.compile(neumf_model)

# Call the compiled model once on the sample batch; the returned tensor is
# shown as the cell output.
torch_compiled(user_ids, item_ids)

tensor([[1.],
        [1.],
        [1.],
        ...,
        [1.],
        [1.],
        [1.]], device='cuda:0')

In [20]:
# Time one forward pass on the same batch for each variant, in this order:
# thunder-jitted, torch.compile'd, eager.
# Each line ends with torch.cuda.synchronize(); as part of the %timeit line it
# is presumably timed together with the forward call, so GPU work is waited for.
%timeit jit_neumf_model(user_ids, item_ids); torch.cuda.synchronize()
%timeit torch_compiled(user_ids, item_ids); torch.cuda.synchronize()
%timeit neumf_model(user_ids, item_ids); torch.cuda.synchronize()

6.77 ms ± 47.9 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
6.13 ms ± 21.3 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
6.83 ms ± 10.4 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [21]:
# Display the last entry of the traces thunder reports for jit_neumf_model.
thunder.last_traces(jit_neumf_model)[-1]

# Constructed by Unwrap the actual return value
import torch
import torch.nn.functional
from thunder.executors.torchex import no_autocast

@torch.no_grad()
@no_autocast
def computation(user_ids, item_ids, t_gmf_item_embedding_weight, t_gmf_user_embedding_weight, t_mlp_item_embedding_weight, t_mlp_mlp_layers_0_bias, t_mlp_mlp_layers_0_weight, t_mlp_mlp_layers_1_bias, t_mlp_mlp_layers_1_weight, t_mlp_mlp_layers_4_bias, t_mlp_mlp_layers_4_weight, t_mlp_mlp_layers_7_bias, t_mlp_mlp_layers_7_weight, t_mlp_user_embedding_weight, t_output_layer_bias, t_output_layer_weight):
  # user_ids: "cuda:0 i64[4096]"
  # item_ids: "cuda:0 i64[4096]"
  # t_gmf_item_embedding_weight: "cuda:0 f32[3706, 1024]"
  # t_gmf_user_embedding_weight: "cuda:0 f32[6040, 1024]"
  # t_mlp_item_embedding_weight: "cuda:0 f32[3706, 1024]"
  # t_mlp_mlp_layers_0_bias: "cuda:0 f32[1024]"
  # t_mlp_mlp_layers_0_weight: "cuda:0 f32[1024, 2048]"
  # t_mlp_mlp_layers_1_bias: "cuda:0 f32[512]"
  # t_mlp_mlp_layers_1_weight: "cud

In [22]:
# Profile a single forward pass of the thunder-jitted model, recording CUDA
# activity only. The profiler names are imported in the notebook's import cell
# so that all dependencies are visible in one place.
with profile(activities=[ProfilerActivity.CUDA], record_shapes=True) as prof, record_function("thunder"):
    out = jit_neumf_model(user_ids, item_ids)

# Print the aggregated event table, sorted by total CUDA time.
print(prof.key_averages().table(sort_by="cuda_time_total"))

-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                 volta_sgemm_128x128_tn         0.00%       0.000us         0.00%       0.000us       0.000us       2.477ms        53.44%       2.477ms       1.239ms             2  
                                  volta_sgemm_128x64_tn         0.00%       0.000us         0.00%       0.000us       0.000us     858.594us        18.52%     858.594us     429.297us             2  
void at::

In [23]:
# Profile a single forward pass of the eager model, recording CUDA activity only.
with profile(activities=[ProfilerActivity.CUDA], record_shapes=True) as prof, record_function("eager"):
    out = neumf_model(user_ids, item_ids)

# Print the aggregated event table, sorted by total CUDA time.
print(prof.key_averages().table(sort_by="cuda_time_total"))

-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                 volta_sgemm_128x128_tn         0.00%       0.000us         0.00%       0.000us       0.000us       2.472ms        50.84%       2.472ms       1.236ms             2  
                                  volta_sgemm_128x64_tn         0.00%       0.000us         0.00%       0.000us       0.000us     859.623us        17.68%     859.623us     429.811us             2  
void at::