In [3]:
import torch
from torch import nn, optim
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm

from dataset.mnist import get_dataloader
from model.resnet import get_resnet

def train(total_epoch: int = 100) -> None:
    """Train the ResNet returned by ``get_resnet`` and log metrics to TensorBoard.

    Each epoch iterates once over the dataloader returned by
    ``get_dataloader(root="data", batch_size=64)``, optimising with SGD under
    a one-cycle learning-rate schedule. After every epoch the mean batch
    loss, mean batch accuracy, and current learning rate are written to the
    ``log`` directory and the loss/accuracy are printed.

    Args:
        total_epoch: Number of passes over the dataloader.
    """
    writer = SummaryWriter(log_dir="log")
    try:
        dataloader = get_dataloader(root="data", batch_size=64)
        steps_per_epoch = len(dataloader)

        model = get_resnet(pretrained=True)
        optimizer = optim.SGD(
            params=model.parameters(),
            lr=1e-3,
        )
        # OneCycleLR is a per-batch schedule: it must span every optimizer
        # step of the whole run. Sizing it to a single epoch while stepping
        # once per epoch either never finishes the cycle or raises once the
        # step count exceeds total_steps.
        # max(..., 1) keeps a zero-epoch call constructible (the loop below
        # simply does not run).
        scheduler = optim.lr_scheduler.OneCycleLR(
            optimizer=optimizer,
            max_lr=1e-3,
            total_steps=max(total_epoch, 1) * steps_per_epoch,
        )
        criterion = nn.CrossEntropyLoss()

        model.train()
        for epoch in range(total_epoch):
            accuracy, train_loss = 0.0, 0.0
            for images, labels in tqdm(dataloader):
                optimizer.zero_grad()

                out = model(images)
                loss = criterion(out, labels)

                loss.backward()
                optimizer.step()
                # Advance the schedule once per optimizer step.
                scheduler.step()

                # Predicted class = index of the largest value along dim 1.
                preds = out.argmax(dim=1)

                # Accumulate per-batch loss and per-batch accuracy.
                train_loss += loss.item()
                accuracy += torch.sum(preds == labels).item() / len(labels)

            # Record the epoch averages (over batches) and the current lr.
            epoch_loss = train_loss / steps_per_epoch
            epoch_accuracy = accuracy / steps_per_epoch
            writer.add_scalar("loss", epoch_loss, epoch)
            writer.add_scalar("accuracy", epoch_accuracy, epoch)
            # get_last_lr() is the supported accessor; get_lr() is meant to
            # be called only from inside step().
            writer.add_scalar("lr", scheduler.get_last_lr()[0], epoch)

            print(f"epoch: {epoch + 1}")
            print(f"loss: {epoch_loss}")
            print(f"accuracy: {epoch_accuracy}")
    finally:
        # Flush pending events and release the event file even on failure.
        writer.close()



ModuleNotFoundError: No module named 'dataset'