In [179]:
import torchvision.transforms as transforms

transform = transforms.Compose([transforms.Normalize((0), (1,)), transforms.ToTensor()]) #transforms.Resize((784, 1))])


In [184]:
import torchvision.datasets as datasets
mnist_trainset = datasets.MNIST(root='./mnist_data', train=True, download=True, transform=None)

In [188]:
imgs, labels = mnist_trainset.data, mnist_trainset.targets
imgs = imgs.reshape(imgs.shape[0], imgs.shape[1] * imgs.shape[2])

imgs = imgs.ca('float32') / 255

AttributeError: 'Tensor' object has no attribute 'astype'

In [186]:
imgs = imgs[:784]
labels = labels[:784]

In [187]:
imgs = imgs.transpose(1, 0)

In [138]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
from transformers import Trainer

class SimpleNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear_0 = nn.Linear(784, 128)
        self.linear_1 = nn.Linear(128, 10)
        
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x.squeeze(-1)
        x = torch.relu(self.linear_0(x))
        x = self.linear_1(x)
        
        return x

In [139]:
dataset = TensorDataset(imgs, labels)

In [140]:
data_loader = DataLoader(dataset,
                        batch_size=32,
                        shuffle=True)

In [141]:
from transformers import Trainer, TrainingArguments

In [159]:
model = SimpleNet()

In [170]:
loss_fn = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

In [174]:
def train_one_epoch(epoch_index, tb_writer):
    running_loss = 0.
    last_loss = 0.

    # Here, we use enumerate(training_loader) instead of
    # iter(training_loader) so that we can track the batch
    # index and do some intra-epoch reporting
    for i, data in enumerate(data_loader):
        # Every data instance is an input + label pair
        inputs, labels = data

        # Zero your gradients for every batch!
        optimizer.zero_grad()
        print(inputs.shape, labels.shape)
        # Make predictions for this batch
        outputs = model(inputs)

        # Compute the loss and its gradients
        loss = loss_fn(outputs, labels)
        loss.backward()

        # Adjust learning weights
        optimizer.step()

        # Gather data and report
        running_loss += loss.item()
        if i % 1000 == 999:
            last_loss = running_loss / 1000 # loss per batch
            print('  batch {} loss: {}'.format(i + 1, last_loss))
            tb_x = epoch_index * len(data_loader) + i + 1
            tb_writer.add_scalar('Loss/train', last_loss, tb_x)
            running_loss = 0.

    return last_loss

In [175]:
from datetime import datetime
from torch.utils.tensorboard import SummaryWriter

In [176]:
# Initializing in a separate cell so we can easily add more epochs to the same run
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
writer = SummaryWriter('runs/fashion_trainer_{}'.format(timestamp))
epoch_number = 0

EPOCHS = 5

best_vloss = 1_000_000.

for epoch in range(EPOCHS):
    print('EPOCH {}:'.format(epoch_number + 1))

    # Make sure gradient tracking is on, and do a pass over the data
    model.train(True)
    avg_loss = train_one_epoch(epoch_number, writer)

    # We don't need gradients on to do reporting
    model.train(False)

    running_vloss = 0.0
    # for i, vdata in enumerate(validation_loader):
    #     vinputs, vlabels = vdata
    #     voutputs = model(vinputs)
    #     vloss = loss_fn(voutputs, vlabels)
    #     running_vloss += vloss

    avg_vloss = running_vloss / (i + 1)
    print('LOSS train {} valid {}'.format(avg_loss, avg_vloss))

    # Log the running loss averaged per batch
    # for both training and validation
    writer.add_scalars('Training vs. Validation Loss',
                    { 'Training' : avg_loss, 'Validation' : avg_vloss },
                    epoch_number + 1)
    writer.flush()

    # Track best performance, and save the model's state
    if avg_vloss < best_vloss:
        best_vloss = avg_vloss
        model_path = 'model_{}_{}'.format(timestamp, epoch_number)
        torch.save(model.state_dict(), model_path)

    epoch_number += 1

EPOCH 1:
torch.Size([32, 784]) torch.Size([32])


RuntimeError: expected scalar type Float but found Byte