In [1]:
#import os
#os.environ['HF_DATASETS_OFFLINE'] = "1"

from datasets import load_dataset
from torchvision.transforms import PILToTensor
from torchvision.transforms.functional import pil_to_tensor
import torchvision.transforms.functional as TFF
from torch.utils.data import Dataset, DataLoader
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import Trainer, TrainingArguments, TrainerCallback

In [2]:
# Load the 'fashion_mnist' dataset; later cells index the result by split name
# (raw_dataset['train']).
raw_dataset = load_dataset('fashion_mnist')

In [28]:
def collate_fn(example):
    """Collate a sequence of dataset rows into a single batch dict.

    Despite the singular name, ``example`` is iterated as a sequence of rows,
    each of which is indexed by ``'image'`` and ``'label'``.

    Returns:
        dict with two entries:
          * ``'input_ids'`` - every image converted with ``TFF.to_tensor``,
            flattened to 1-D and stacked along a new leading batch dimension.
          * ``'labels'`` - the per-row labels as ``torch.long`` tensors,
            stacked along a new leading batch dimension.
    """
    pixel_rows = []
    targets = []
    for row in example:
        pixel_rows.append(TFF.to_tensor(row['image']).flatten())
        targets.append(torch.tensor(row['label'], dtype=torch.long))

    return {'input_ids': torch.stack(pixel_rows), 'labels': torch.stack(targets)}

def collate_fn2(example):
    """Add tensor columns to a column-oriented batch.

    Intended for ``Dataset.map(..., batched=True)``: ``example`` is indexed by
    column name, and ``example['image']`` / ``example['label']`` are iterated
    as per-row sequences.

    The mapping is mutated in place - ``'input_ids'`` (flattened image tensors,
    stacked) and ``'labels'`` (label tensors, stacked) are added - and the same
    object is returned.
    """
    flat_images = [TFF.to_tensor(img).flatten() for img in example['image']]
    example['input_ids'] = torch.stack(flat_images)

    label_tensors = [torch.tensor(lbl) for lbl in example['label']]
    example['labels'] = torch.stack(label_tensors)

    return example

# Training split of the raw dataset.
ds_train = raw_dataset['train']

# Pre-tensorized copy of the training split: collate_fn2 adds the
# 'input_ids' / 'labels' columns batch by batch, and the original columns are
# removed from the result.
ds2 = ds_train.map(
    function=collate_fn2,
    batched=True,
    remove_columns=ds_train.column_names,
)

# Batches of 4 raw rows, converted to tensors on the fly by collate_fn.
train_dl = DataLoader(dataset=ds_train, batch_size=4, collate_fn=collate_fn)


In [29]:
# Peek at the first batch produced by the DataLoader to check its structure.
next(iter(train_dl))

{'input_ids': tensor([[0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.]]),
 'labels': tensor([9, 0, 0, 3])}

In [31]:
class MyModel(nn.Module):
    """Single linear layer classifier (multinomial logistic regression).

    Args:
        *args, **kwargs: Forwarded unchanged to ``nn.Module.__init__``.
        in_features: Keyword-only. Size of each flattened input vector.
            Defaults to 784, the value that was previously hard-coded.
        out_features: Keyword-only. Number of output classes. Defaults to 10,
            the value that was previously hard-coded.
    """

    def __init__(self, *args, in_features: int = 784, out_features: int = 10, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.linear = nn.Linear(in_features=in_features, out_features=out_features)

    def forward(self, input_ids, labels=None):
        """Compute class logits and, when targets are supplied, the loss.

        Args:
            input_ids: Float tensor whose last dimension equals ``in_features``.
                The name mirrors the batch key used elsewhere in this
                notebook; the values are flattened pixels, not token ids.
            labels: Optional class-index targets accepted by
                ``F.cross_entropy``. When omitted, no loss is computed so the
                model can be used for inference.

        Returns:
            ``(loss, logits)`` when ``labels`` is given (unchanged from the
            original behaviour), otherwise the one-element tuple ``(logits,)``.
        """
        logits = self.linear(input_ids)

        # Without targets there is nothing to score: return the logits only.
        if labels is None:
            return (logits,)

        loss = F.cross_entropy(logits, labels)
        return (loss, logits)

In [32]:
# Keep the first batch in a variable so it can be fed to the model in the next cell.
batch = next(iter(train_dl))
batch

{'input_ids': tensor([[0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.]]),
 'labels': tensor([9, 0, 0, 3])}

In [33]:
# Smoke test: build a fresh model and run one forward pass on the saved batch;
# the batch keys ('input_ids', 'labels') are passed as keyword arguments.
model = MyModel()
model(**batch)

(tensor(2.5039, grad_fn=<NllLossBackward0>),
 tensor([[-0.2855, -0.2894,  0.0510,  0.2150, -0.5922, -0.7458,  0.6226, -0.0336,
           0.5699,  0.0694],
         [-0.4961, -0.0712, -0.0452, -0.1188, -0.0774, -0.4007,  0.4912,  0.1807,
           0.6254, -0.1754],
         [-0.2056, -0.1007, -0.0309,  0.0482, -0.0984, -0.0907,  0.4011,  0.0223,
           0.1133, -0.1061],
         [-0.0697, -0.1477,  0.0651, -0.0383, -0.0787, -0.0981,  0.3791,  0.0193,
           0.2459, -0.1507]], grad_fn=<AddmmBackward0>))

In [34]:
optimizer = torch.optim.AdamW(params=model.parameters(), lr=1e-5)
for step, batch in enumerate(train_dl):
    # Forward pass: the model returns the loss together with the logits.
    loss, logits = model(**batch)

    # Backward pass and parameter update; gradients are cleared afterwards so
    # they do not accumulate into the next iteration.
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

    # Report the loss of the current batch every 1000 steps only.
    if step % 1000 != 0:
        continue
    print(f"step {step} loss: {loss.item():>7f}")

step 0 loss: 2.503943
step 1000 loss: 1.782980
step 2000 loss: 1.820226
step 3000 loss: 1.523874
step 4000 loss: 1.953312
step 5000 loss: 1.424843
step 6000 loss: 1.150937
step 7000 loss: 1.324925
step 8000 loss: 1.228642
step 9000 loss: 0.832661
step 10000 loss: 0.829881
step 11000 loss: 0.862198
step 12000 loss: 0.893969
step 13000 loss: 0.714225
step 14000 loss: 0.566925


In [46]:

class AddToLog(TrainerCallback):
    """Trainer callback that stamps every log record with the current global step."""

    def on_log(self, args, state, control, logs=None, **kwargs):
        """Add ``state.global_step`` to ``logs`` under the key ``'step'``.

        Args:
            args: Training arguments passed by the trainer (unused here).
            state: Trainer state; only ``state.global_step`` is read.
            control: Trainer control object (unused here).
            logs: Log record to annotate. It is mutated in place. May be
                ``None`` (the declared default), in which case there is
                nothing to annotate and the call is a no-op.
            **kwargs: Additional keyword arguments from the trainer (ignored).
        """
        # The signature allows logs=None; subscripting None would raise TypeError.
        if logs is None:
            return
        logs['step'] = state.global_step

# Training configuration: output directory 'test_train', logging every 1000
# steps, evaluation disabled.
args = TrainingArguments(
    output_dir='test_train',
    logging_steps=1000,
    do_eval=False,
)

# Train the existing `model` on the pre-tensorized dataset `ds2`, with the
# AddToLog callback attached.
trainer = Trainer(
    model=model,
    args=args,
    train_dataset=ds2,
    callbacks=[AddToLog()],
)

In [47]:
# Start the Trainer run. The recorded output below shows this run being
# interrupted manually (KeyboardInterrupt) at 0/22500 steps.
trainer.train()

  0%|          | 0/22500 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [45]:
# Inspect the raw training split (feature names and row count).
ds_train

Dataset({
    features: ['image', 'label'],
    num_rows: 60000
})