## Import Library

In [2]:
## Standard libraries
import os
import numpy as np
import random
import math
import json
from functools import partial

## Imports for plotting
import matplotlib.pyplot as plt
plt.set_cmap('cividis')
%matplotlib inline
from IPython.display import set_matplotlib_formats
set_matplotlib_formats('svg', 'pdf') # For export
from matplotlib.colors import to_rgb
import matplotlib
matplotlib.rcParams['lines.linewidth'] = 2.0
import seaborn as sns
sns.reset_orig()

## tqdm for loading bars
from tqdm.notebook import tqdm

## PyTorch
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as data
import torch.optim as optim

## Torchvision
import torchvision
from torchvision.datasets import CIFAR100
from torchvision import transforms

## Lightning
import pytorch_lightning as pl
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint

  set_matplotlib_formats('svg', 'pdf') # For export
  from .autonotebook import tqdm as notebook_tqdm


### Dataset Path
Mohon sesuaikan path dataset dan path checkpoint di bawah ini dengan preferensi Anda.

In [3]:
# Directory for datasets (not referenced in the visible cells; presumably used
# by later cells -- TODO confirm).
DATASET_PATH = "../data"
# Directory into which the pretrained checkpoint files are downloaded.
CHECKPOINT_PATH = "../checkpoints"

### Konfigurasi Lainnya

In [4]:
# Random seed for PyTorch Lightning
pl.seed_everything(42)

# GPU configuration: use the first CUDA device when one is available,
# otherwise fall back to the CPU
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print(f"Device: {device}")

Global seed set to 42


Device: cpu


### Load Model
Model bersumber dari [https://github.com/phlippe/saved_models/tree/main/tutorial6](https://github.com/phlippe/saved_models/tree/main/tutorial6)

In [5]:
import urllib.request
from urllib.error import HTTPError

# Base GitHub URL under which the saved models for this tutorial are hosted
base_url = "https://raw.githubusercontent.com/phlippe/saved_models/main/tutorial6/"
# Checkpoint files to fetch
pretrained_files = ["ReverseTask.ckpt", "SetAnomalyTask.ckpt"]

# Make sure the checkpoint directory exists before writing into it
os.makedirs(CHECKPOINT_PATH, exist_ok=True)

# Download every checkpoint that is not already present on disk
for file_name in pretrained_files:
    file_path = os.path.join(CHECKPOINT_PATH, file_name)

    # A file name containing "/" lives in a sub-folder; create that folder first
    if "/" in file_name:
        os.makedirs(file_path.rpartition("/")[0], exist_ok=True)

    # Nothing to do when the file has already been downloaded
    if os.path.isfile(file_path):
        continue

    file_url = f"{base_url}{file_name}"
    print(f"Downloading {file_url}...")
    try:
        urllib.request.urlretrieve(file_url, file_path)
    except HTTPError as err:
        # Report the HTTP failure and carry on with the remaining files
        print("Something went wrong. Please try to download the file from the GDrive folder, or contact the author with the full output including the following error:\n", err)

Downloading https://raw.githubusercontent.com/phlippe/saved_models/main/tutorial6/ReverseTask.ckpt...
Downloading https://raw.githubusercontent.com/phlippe/saved_models/main/tutorial6/SetAnomalyTask.ckpt...


### Fungsi `scaled_dot_product`

In [6]:
def scaled_dot_product(q, k, v, mask=None):
    """Compute scaled dot-product attention.

    The attention logits are the matrix product of ``q`` with ``k`` transposed
    over its last two dimensions, divided by the square root of the size of
    ``q``'s last dimension. A softmax over the last dimension turns the logits
    into attention weights, which are then matrix-multiplied with ``v``.

    Args:
        q: Query tensor.
        k: Key tensor; ``k.transpose(-2, -1)`` must be matmul-compatible
            with ``q``.
        v: Value tensor; must be matmul-compatible with the attention weights.
        mask: Optional tensor usable as a ``masked_fill`` mask on the logits.
            Positions where ``mask == 0`` are filled with a large negative
            value before the softmax.

    Returns:
        A tuple ``(values, attention)``: the attention-weighted combination of
        ``v`` and the attention weights themselves.
    """
    # d_k is the size of q along its last dimension
    d_k = q.size(-1)

    # Multiply q with k transposed over its last two dimensions, then scale
    # the result by 1 / sqrt(d_k)
    attn_logits = torch.matmul(q, k.transpose(-2, -1))
    attn_logits = attn_logits / math.sqrt(d_k)

    if mask is not None:
        # -9e15 cannot be represented in narrow float types such as float16,
        # so clamp the fill value to the most negative finite value of the
        # logits' dtype. For float32/float64 this is still exactly -9e15.
        fill_value = max(-9e15, torch.finfo(attn_logits.dtype).min)
        attn_logits = attn_logits.masked_fill(mask == 0, fill_value)

    attention = F.softmax(attn_logits, dim=-1)
    values = torch.matmul(attention, v)
    return values, attention

### Coret-coretan

In [59]:
# Scratch: print a random tensor before and after swapping its last two dimensions
k = torch.randn(3, 2, 4)
print(k)
k = torch.transpose(k, -1, -2)
print(k)

tensor([[[ 2.2516, -0.3681, -1.0801,  1.9882],
         [ 0.1035,  0.1166, -0.0252,  0.3789]],

        [[ 2.1100,  0.7726, -0.3427,  1.2370],
         [ 1.1887,  1.0021, -1.4647, -0.0302]],

        [[ 0.3997,  0.6982,  0.0521,  0.2882],
         [ 0.3520, -0.3862,  0.3568,  1.7114]]])
tensor([[[ 2.2516,  0.1035],
         [-0.3681,  0.1166],
         [-1.0801, -0.0252],
         [ 1.9882,  0.3789]],

        [[ 2.1100,  1.1887],
         [ 0.7726,  1.0021],
         [-0.3427, -1.4647],
         [ 1.2370, -0.0302]],

        [[ 0.3997,  0.3520],
         [ 0.6982, -0.3862],
         [ 0.0521,  0.3568],
         [ 0.2882,  1.7114]]])


Referensi:
- https://www.youtube.com/watch?v=hGZ6wa07Vak&list=PLdlPlO1QhMiAkedeu0aJixfkknLRxk1nA&index=11