In [1]:
# Enable IPython's autoreload extension in mode 2, so imported modules are
# reloaded before each cell executes and edits to local source files are
# picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
from adafactor_pytorch import Adafactor
from torch.optim import AdamW, Adam
from lion_pytorch import Lion
import torchvision
from datasets import load_dataset
from torchvision import transforms
import torch
import numpy as np
from tqdm.auto import tqdm
import matplotlib.pyplot as plt
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

In [3]:
# Image preprocessing: convert each image to a tensor, then normalise it with
# mean 0.5 and std 0.5.
to_tensor = transforms.ToTensor()
normalize = transforms.Normalize([0.5], [0.5])
train_transforms = transforms.Compose([to_tensor, normalize])

# Only the training split of MNIST is used in this notebook.
dataset = load_dataset("mnist")["train"]
def map_label2one_hot(label, num_classes=10):
    """Return a one-hot vector encoding ``label``.

    Args:
        label: Integer class index; used directly to index the output vector.
        num_classes: Length of the output vector. Defaults to 10, the value
            that was previously hard-coded.

    Returns:
        A 1-D float64 NumPy array of length ``num_classes`` that is zero
        everywhere except for a 1 at position ``label``.

    Raises:
        IndexError: If ``label`` is outside the valid index range.
    """
    out = np.zeros(num_classes)
    out[label] = 1
    return out
def transform_func(examples):
    """Convert a batch of raw examples into tensors, modifying it in place.

    Args:
        examples: Mapping with an "image" entry (iterable of images accepted by
            ``train_transforms``) and a "label" entry (iterable of class indices).

    Returns:
        The same ``examples`` object, with "image" replaced by a list of
        transformed images and "label" replaced by a list of one-hot tensors.
    """
    examples["image"] = list(map(train_transforms, examples["image"]))

    one_hot_labels = []
    for class_index in examples["label"]:
        one_hot_labels.append(torch.tensor(map_label2one_hot(class_index)))
    examples["label"] = one_hot_labels
    return examples
# Attach transform_func to the dataset so examples come out as tensors.
# NOTE(review): per the `datasets` docs, with_transform applies the function
# when items are accessed rather than eagerly — confirm for the installed version.
dataset = dataset.with_transform(transform_func)

def collate_fn(examples):
    """Stack per-example tensors into a single batch dictionary.

    Args:
        examples: Sequence of mappings, each holding an "image" tensor and a
            "label" tensor; all images (and all labels) must share a shape.

    Returns:
        A dict with "pixel_values" (the stacked images, contiguous and cast to
        float32) and "labels" (the stacked labels, dtype left unchanged).
    """
    image_batch = torch.stack(tuple(item["image"] for item in examples))
    label_batch = torch.stack(tuple(item["label"] for item in examples))
    return {
        "pixel_values": image_batch.to(memory_format=torch.contiguous_format).float(),
        "labels": label_batch,
    }

# Training DataLoader: shuffled mini-batches of 128 examples, assembled by
# collate_fn, with no worker subprocesses (num_workers=0).
train_dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=128,
    shuffle=True,
    num_workers=0,
    collate_fn=collate_fn,
)

In [4]:
# Adafactor with the Triton kernel enabled.
# NOTE: the recorded run of this cell failed while Triton compiled its kernel
# ("Python.h: No such file or directory"), so use_triton=True needs the Python
# development headers to be installed on the machine.
#
# The model emits raw logits: torch.nn.CrossEntropyLoss applies log-softmax
# itself, so a Softmax layer in the model would normalise the outputs twice.
model = torch.nn.Sequential(torch.nn.Linear(28*28, 10)).to(device)
optimizer = Adafactor(model.parameters(), betas=(0.9, 0.99), lr=1e-4, use_triton=True)
ce_loss = torch.nn.CrossEntropyLoss()
losses = []  # one scalar loss per optimisation step
for _ in range(3):
    for batch in tqdm(train_dataloader):
        pixel_values = batch['pixel_values'].to(device)
        labels = batch['labels'].to(device)
        batch_size = pixel_values.shape[0]
        # Flatten each image into a single feature vector for the linear layer.
        pixel_values = pixel_values.reshape((batch_size, -1))
        # backward() accumulates into .grad, so clear the previous step's gradients.
        optimizer.zero_grad()
        predicted = model(pixel_values)
        loss = ce_loss(predicted, labels)
        losses.append(loss.item())
        loss.backward()
        optimizer.step()
    # Report the loss of the last batch of this pass.
    print(losses[-1])
plt.plot(losses)

  0%|          | 0/469 [00:00<?, ?it/s]

torch.Size([10, 784])


/tmp/tmppx7ab4f2/main.c:4:10: fatal error: Python.h: No such file or directory
 #include <Python.h>
          ^~~~~~~~~~
compilation terminated.


Unexpected exception formatting exception. Falling back to standard exception


Traceback (most recent call last):
  File "<string>", line 21, in matrix_update_fn_kernel
KeyError: ('2-.-0-.-0--d6252949da17ceb5f3a278a70250af13-3b85c7bef5f0a641282f3b73af50f599-3d2aedeb40d6d81c66a42791e268f98b-3498c340fd4b6ee7805fd54b882a04f5-e1f133f98d04093da2078dfc51c36b72-b26258bf01f839199e39d64851821f26-d7c06e3b46e708006c15224aac7a1378-f585402118c8a136948ce0a49cfe122c', (torch.float32, torch.float32, torch.float32, torch.float32, torch.float32, 'fp32', 'fp32', 'fp32', 'fp32', 'fp32', 'fp32', 'i1', 'fp32', 'i32', 'i32', 'i32'), (128,), (True, True, True, True, True, (False,), (False,), (False,), (False,), (False,), (False,), (True, False), (False,), (True, False), (False, False), (True, False)))

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/home/isamu/.local/lib/python3.10/site-packages/IPython/core/interactiveshell.py", line 3460, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "/tmp/ipyk

In [None]:
# Adafactor without the Triton kernel, using a fixed learning rate
# (relative_step and scale_parameter are both switched off).
# The model emits raw logits: torch.nn.CrossEntropyLoss applies log-softmax
# itself, so a Softmax layer in the model would normalise the outputs twice.
model = torch.nn.Sequential(torch.nn.Linear(28*28, 10)).to(device)
optimizer = Adafactor(model.parameters(), betas=(0.9, 0.99), lr=1e-4, relative_step=False, scale_parameter=False)
ce_loss = torch.nn.CrossEntropyLoss()

In [None]:
# Train for 3 passes over train_dataloader, recording the loss of every batch.
losses = []  # one scalar loss per optimisation step
for _ in range(3):
    for batch in tqdm(train_dataloader):
        pixel_values = batch['pixel_values'].to(device)
        labels = batch['labels'].to(device)
        batch_size = pixel_values.shape[0]
        # Flatten each image into a single feature vector for the linear layer.
        pixel_values = pixel_values.reshape((batch_size, -1))
        # backward() accumulates into .grad, so clear the previous step's gradients.
        optimizer.zero_grad()
        predicted = model(pixel_values)
        loss = ce_loss(predicted, labels)
        losses.append(loss.item())
        loss.backward()
        optimizer.step()
    # Report the loss of the last batch of this pass.
    print(losses[-1])
    

In [None]:
# Loss curve for the Adafactor run: one point per optimisation step.
fig, ax = plt.subplots()
ax.plot(losses)
ax.set(title="Adafactor training loss", xlabel="Training step", ylabel="Cross-entropy loss")
plt.show()

Test AdamW

In [None]:
# Fresh linear classifier trained with AdamW.
# The model emits raw logits: torch.nn.CrossEntropyLoss applies log-softmax
# itself, so a Softmax layer in the model would normalise the outputs twice.
model = torch.nn.Sequential(torch.nn.Linear(28*28, 10)).to(device)
optimizer = AdamW(model.parameters(), betas=(0.9, 0.99), lr=1e-4)
ce_loss = torch.nn.CrossEntropyLoss()

In [None]:
# Train for 3 passes over train_dataloader, recording the loss of every batch.
losses = []  # one scalar loss per optimisation step
for _ in range(3):
    for batch in tqdm(train_dataloader):
        pixel_values = batch['pixel_values'].to(device)
        labels = batch['labels'].to(device)
        batch_size = pixel_values.shape[0]
        # Flatten each image into a single feature vector for the linear layer.
        pixel_values = pixel_values.reshape((batch_size, -1))
        # backward() accumulates into .grad, so clear the previous step's gradients.
        optimizer.zero_grad()
        predicted = model(pixel_values)
        loss = ce_loss(predicted, labels)
        losses.append(loss.item())
        loss.backward()
        optimizer.step()
    # Report the loss of the last batch of this pass.
    print(losses[-1])
    

In [None]:
# Loss curve for the AdamW run: one point per optimisation step.
fig, ax = plt.subplots()
ax.plot(losses)
ax.set(title="AdamW training loss", xlabel="Training step", ylabel="Cross-entropy loss")
plt.show()

In [None]:
# Fresh linear classifier trained with Lion (note the smaller learning rate
# of 1e-5 compared with the 1e-4 used for the other optimizers).
# The model emits raw logits: torch.nn.CrossEntropyLoss applies log-softmax
# itself, so a Softmax layer in the model would normalise the outputs twice.
model = torch.nn.Sequential(torch.nn.Linear(28*28, 10)).to(device)
optimizer = Lion(model.parameters(), betas=(0.9, 0.99), lr=1e-5)
ce_loss = torch.nn.CrossEntropyLoss()

In [None]:
# Train for 3 passes over train_dataloader, recording the loss of every batch.
losses = []  # one scalar loss per optimisation step
for _ in range(3):
    for batch in tqdm(train_dataloader):
        pixel_values = batch['pixel_values'].to(device)
        labels = batch['labels'].to(device)
        batch_size = pixel_values.shape[0]
        # Flatten each image into a single feature vector for the linear layer.
        pixel_values = pixel_values.reshape((batch_size, -1))
        # backward() accumulates into .grad, so clear the previous step's gradients.
        optimizer.zero_grad()
        predicted = model(pixel_values)
        loss = ce_loss(predicted, labels)
        losses.append(loss.item())
        loss.backward()
        optimizer.step()
    # Report the loss of the last batch of this pass.
    print(losses[-1])
    

In [None]:
# Loss curve for the Lion run: one point per optimisation step.
fig, ax = plt.subplots()
ax.plot(losses)
ax.set(title="Lion training loss", xlabel="Training step", ylabel="Cross-entropy loss")
plt.show()