In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import wandb

# Khởi tạo W&B
wandb.init(project="mnist-model-with-artifact")

# Tải MNIST dataset
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)

# Mô hình đơn giản (2 lớp fully connected)
class MNISTModel(nn.Module):
    def __init__(self):
        super(MNISTModel, self).__init__()
        self.fc1 = nn.Linear(28 * 28, 128)  # Mỗi ảnh MNIST có kích thước 28x28
        self.fc2 = nn.Linear(128, 10)  # 10 classes (0-9)

    def forward(self, x):
        x = x.view(-1, 28 * 28)  # Flatten ảnh thành vector
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)
        return x

model = MNISTModel()

# Loss function và optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Huấn luyện mô hình
for epoch in range(5):  # Train trong 5 epochs
    running_loss = 0.0
    for images, labels in train_loader:
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    
    avg_loss = running_loss / len(train_loader)
    print(f"Epoch {epoch+1}, Loss: {avg_loss}")
    wandb.log({"epoch": epoch + 1, "loss": avg_loss})

# Lưu mô hình sau khi huấn luyện
torch.save(model.state_dict(), "mnist_model.pth")

# Tạo artifact để lưu mô hình
model_artifact = wandb.Artifact('mnist_model', type='model')
model_artifact.add_file('mnist_model.pth')

# Log artifact vào W&B
wandb.log_artifact(model_artifact)

# Kết thúc phiên W&B
wandb.finish()


[34m[1mwandb[0m: Using wandb-core as the SDK backend. Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mfire-kidboy1505[0m ([33mfire-kidboy1505-vietnam-national-university-hanoi[0m). Use [1m`wandb login --relogin`[0m to force relogin


Epoch 1, Loss: 0.38141130430421344
Epoch 2, Loss: 0.20243060422628356
Epoch 3, Loss: 0.14681921711664148
Epoch 4, Loss: 0.11722750216474268
Epoch 5, Loss: 0.09936532973429002


VBox(children=(Label(value='0.006 MB of 0.006 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
epoch,▁▃▅▆█
loss,█▄▂▁▁

0,1
epoch,5.0
loss,0.09937


In [None]:
!tree artifacts