In [1]:
import torch
import yaml
from config import BitformerConfig
from model_zoo import VisionBitformerForImageClassification
from trainer import get_trainer
from data_zoo import get_mnist

In [2]:
# Location of the experiment configuration (relative to the notebook's working directory).
yaml_path = './yamls/mnist.yaml'

# Decode the file explicitly as UTF-8 so parsing does not depend on the
# platform's default text encoding (e.g. a legacy code page on Windows).
# safe_load is used, so the YAML can only produce plain Python data.
with open(yaml_path, 'r', encoding='utf-8') as file:
    args = yaml.safe_load(file)

In [3]:
# Echo the parsed YAML so the configuration in effect is shown as the cell output.
args

{'general_config': {'model_type': 'VisionBitformerForImageClassification',
  'num_labels': 10},
 'model_config': {'vocab_size': 32000,
  'hidden_size': 28,
  'intermediate_size': 512,
  'num_hidden_layers': 1,
  'num_attention_heads': 2,
  'num_key_value_heads': 2,
  'hidden_act': 'silu',
  'max_position_embeddings': 4096,
  'initializer_range': 0.02,
  'rms_norm_eps': 1e-05,
  'use_cache': False,
  'pad_token_id': None,
  'bos_token_id': 1,
  'eos_token_id': 2,
  'tie_word_embeddings': False,
  'rope_theta': 1000000.0,
  'sliding_window': 4096,
  'attention_dropout': 0.0,
  'num_experts_per_tok': 2,
  'num_local_experts': 4,
  'output_router_logits': True,
  'router_aux_loss_coef': 0.001,
  'is_causal': False,
  'moe': True,
  'bitnet': False},
 'training_args': {'output_dir': './results',
  'logging_dir': './logs',
  'report_to': None,
  'evaluation_strategy': 'epoch',
  'per_device_train_batch_size': 64,
  'per_device_eval_batch_size': 64,
  'gradient_accumulation_steps': 1,
  'lear

In [4]:
# Build the model configuration from the `model_config` section of the YAML;
# the number of output classes comes from `general_config`.
cfg = BitformerConfig(
    **args['model_config'],
    num_labels=args['general_config']['num_labels'],
)

# Instantiate the classifier from that configuration.
model = VisionBitformerForImageClassification(config=cfg)

# Last expression of the cell: parameter count, scaled to millions.
model.num_parameters() / 1e6

1.071644

In [5]:
from torchvision import transforms, datasets


def data_collator(batch):
    """Collate (image, label) samples into one batch dictionary.

    Args:
        batch: sequence of samples; ``sample[0]`` is an image tensor that can
            be reshaped to 28x28 and ``sample[1]`` is its label.

    Returns:
        dict with
            ``'inputs_embeds'``: the images stacked along a new first axis,
                each reshaped to ``(28, 28)``;
            ``'labels'``: a tensor built from the labels, in batch order.
    """
    images = []
    targets = []
    for sample in batch:
        # Drop a leading singleton axis if present, then force a 28x28 layout.
        # NOTE(review): presumably each 28-pixel row is fed to the model as one
        # embedding vector (the printed config has hidden_size=28) — confirm.
        images.append(sample[0].squeeze(0).reshape(28, 28))
        targets.append(sample[1])
    return {
        'inputs_embeds': torch.stack(images),
        'labels': torch.tensor(targets),
    }

# Preprocessing applied to every image: convert to a tensor, then normalise
# the single channel with the given mean / standard deviation.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.1307,), std=(0.3081,)),
])

# MNIST train / test splits, downloaded into ./mnist_data when not already present.
train_dataset = datasets.MNIST(
    root='./mnist_data',
    train=True,
    download=True,
    transform=transform,
)
test_dataset = datasets.MNIST(
    root='./mnist_data',
    train=False,
    download=True,
    transform=transform,
)

In [6]:
# Build the trainer from the model, the train/test datasets and the collate
# function; every entry of the YAML `training_args` section is forwarded as a
# keyword argument to get_trainer.
trainer = get_trainer(model, train_dataset, test_dataset, data_collator, **args['training_args'])

In [7]:
# Run the full training loop, then evaluate once more; the metrics dict returned
# by evaluate() is the last expression and is therefore shown as the cell output.
trainer.train()
trainer.evaluate()

  0%|          | 0/9380 [00:00<?, ?it/s]

Could not estimate the number of tokens of the input, floating-point operations will not be computed


{'loss': 1.8517, 'grad_norm': 2.1099369525909424, 'learning_rate': 0.0004977113991447017, 'epoch': 0.53}


  0%|          | 0/157 [00:00<?, ?it/s]

Confusion Matrix:
[[ 765    0    7    0    6   25   37    1  134    5]
 [   0 1115    5    2    0    4    6    0    3    0]
 [  20    0  675  174   27   54   64    4   14    0]
 [   4   17   97  745   10   42    4   51   17   23]
 [  41    0   13   16  673    7   22   13   28  169]
 [  41   10   87  173   25  360   29   20  135   12]
 [  51   14   31    0    6   13  832    0   11    0]
 [   0   14   14   53   10    1    0  838    3   95]
 [  68   17    4   21   22   27   22    2  766   25]
 [  12    8    1   19   73    5    0  134   22  735]]
{'eval_loss': 0.7526538968086243, 'eval_f1': 0.7456515832870662, 'eval_precision': 0.7501509712491882, 'eval_recall': 0.7504, 'eval_accuracy': 0.7504, 'eval_runtime': 2.1495, 'eval_samples_per_second': 4652.276, 'eval_steps_per_second': 73.041, 'epoch': 1.0}
{'loss': 0.9053, 'grad_norm': 1.7508418560028076, 'learning_rate': 0.0004884857252366847, 'epoch': 1.07}
{'loss': 0.6585, 'grad_norm': 2.313500165939331, 'learning_rate': 0.0004724434323532821

  0%|          | 0/157 [00:00<?, ?it/s]

Confusion Matrix:
[[ 825    0    3    0    0   31   27    3   89    2]
 [   0 1112    9    1    0    3    4    0    6    0]
 [   6    7  827  101    3   55   20    4    8    1]
 [   0   14   58  871    2   18    0   31    9    7]
 [   9    4   17    8  773    5    5   21   10  130]
 [  10    7   42  142    8  623   13    3   36    8]
 [  28   12   34    0    7   20  846    1   10    0]
 [   0   13   15   14    8    2    0  919    2   55]
 [  42   12    4   21   13   67   18    4  785    8]
 [   5    5    0   18   44   21    0   68    8  840]]
{'eval_loss': 0.4991087317466736, 'eval_f1': 0.8419552749689246, 'eval_precision': 0.8439905647740618, 'eval_recall': 0.8421, 'eval_accuracy': 0.8421, 'eval_runtime': 2.1488, 'eval_samples_per_second': 4653.858, 'eval_steps_per_second': 73.066, 'epoch': 2.0}
{'loss': 0.5284, 'grad_norm': 3.591857433319092, 'learning_rate': 0.00045004305610692587, 'epoch': 2.13}
{'loss': 0.4514, 'grad_norm': 3.017133951187134, 'learning_rate': 0.0004219248647133558

  0%|          | 0/157 [00:00<?, ?it/s]

Confusion Matrix:
[[ 902    0    1    1    1    5   32    2   34    2]
 [   0 1105    6    2    0    5   13    0    4    0]
 [  10    2  876   39    5   47   40    3    9    1]
 [   0    7   69  854    0   34    0   22   10   14]
 [   8    3   12    1  872    2    9    6    5   64]
 [   8    3   21   73    6  733   10    4   28    6]
 [  23    5    7    0    6   10  902    0    5    0]
 [   0   17   20   12   11    5    0  881    2   80]
 [  80   13    5   17   15   39   38    3  756    8]
 [   9    7    0   15   73   15    0   32   11  847]]
{'eval_loss': 0.3913519084453583, 'eval_f1': 0.8723047395892015, 'eval_precision': 0.8731715651243441, 'eval_recall': 0.8728, 'eval_accuracy': 0.8728, 'eval_runtime': 2.198, 'eval_samples_per_second': 4549.644, 'eval_steps_per_second': 71.429, 'epoch': 3.0}
{'loss': 0.4087, 'grad_norm': 3.228550434112549, 'learning_rate': 0.00038889255825490053, 'epoch': 3.2}
{'loss': 0.3771, 'grad_norm': 3.2687485218048096, 'learning_rate': 0.00035189029658340025

  0%|          | 0/157 [00:00<?, ?it/s]

Confusion Matrix:
[[ 884    0    3    0    2    4   31    1   48    7]
 [   0 1105   10    4    0    3    7    0    6    0]
 [   6    2  863   97    6   17   22    6   10    3]
 [   0    7   27  933    0    9    0   19    5   10]
 [   2    3    4    1  884    0    4    9    5   70]
 [   2    7   17  130   10  660   11    5   39   11]
 [   5    4   12    1   15    4  906    2    9    0]
 [   0    9   13   11    8    1    0  921    0   65]
 [  27   12    9   22   15   16   18    5  832   18]
 [   3    7    3   11   40    3    0   31   12  899]]
{'eval_loss': 0.3567127287387848, 'eval_f1': 0.888568893545601, 'eval_precision': 0.892290807495399, 'eval_recall': 0.8887, 'eval_accuracy': 0.8887, 'eval_runtime': 2.391, 'eval_samples_per_second': 4182.27, 'eval_steps_per_second': 65.662, 'epoch': 4.0}
{'loss': 0.3565, 'grad_norm': 1.8129571676254272, 'learning_rate': 0.00031197571247243834, 'epoch': 4.26}
{'loss': 0.3294, 'grad_norm': 3.4569573402404785, 'learning_rate': 0.00027028968138185784,

  0%|          | 0/157 [00:00<?, ?it/s]

Confusion Matrix:
[[ 942    0    1    0    1    4    9    0   19    4]
 [   0 1102    7    1    0    4    7    0   14    0]
 [  14    2  849   91    3   20   24   11   13    5]
 [   0    4   20  930    0   15    0   18   15    8]
 [   3    4    5    0  874    0    2    9    9   76]
 [   3    3    8   69    3  751    5    5   36    9]
 [  14    4   14    1   13   11  889    1   10    1]
 [   0   11   15   11   11    1    0  935    1   43]
 [  31    8    7   22    9   19   10    3  857    8]
 [   7    8    1   10   35    5    0   36   17  890]]
{'eval_loss': 0.3156144917011261, 'eval_f1': 0.9018494491308952, 'eval_precision': 0.903299262243816, 'eval_recall': 0.9019, 'eval_accuracy': 0.9019, 'eval_runtime': 2.1664, 'eval_samples_per_second': 4615.988, 'eval_steps_per_second': 72.471, 'epoch': 5.0}
{'loss': 0.318, 'grad_norm': 3.1098461151123047, 'learning_rate': 0.00022802371190303695, 'epoch': 5.33}
{'loss': 0.3066, 'grad_norm': 2.346620559692383, 'learning_rate': 0.00018638588896129557

  0%|          | 0/157 [00:00<?, ?it/s]

Confusion Matrix:
[[ 922    0    2    1    0    4   37    0   11    3]
 [   0 1110    6    1    0    4    6    0    8    0]
 [   7    2  899   48    3   25   33    5   10    0]
 [   1    9   56  879    0   34    0   14   15    2]
 [   2    4    6    0  905    4   10    8    8   35]
 [   2    6   14   40    2  790   11    2   21    4]
 [   7    4    7    1    2    4  927    2    4    0]
 [   0   20   19   12   15    4    0  915    4   39]
 [  32   12    8   19    6   32   29    2  830    4]
 [   7   10    6   15   57   10    0   36   24  844]]
{'eval_loss': 0.304097056388855, 'eval_f1': 0.9017282130790071, 'eval_precision': 0.9022932979024486, 'eval_recall': 0.9021, 'eval_accuracy': 0.9021, 'eval_runtime': 2.1704, 'eval_samples_per_second': 4607.478, 'eval_steps_per_second': 72.337, 'epoch': 6.0}
{'loss': 0.291, 'grad_norm': 1.29850435256958, 'learning_rate': 0.0001465663432182349, 'epoch': 6.4}
{'loss': 0.2861, 'grad_norm': 2.469521999359131, 'learning_rate': 0.00010970323365940444, 'e

  0%|          | 0/157 [00:00<?, ?it/s]

Confusion Matrix:
[[ 953    0    3    1    0    3    7    0   10    3]
 [   0 1105    6    1    0    4    6    0   12    1]
 [   6    5  904   58    3   15   20    6   12    3]
 [   0    5   42  924    0   15    0   11    9    4]
 [   2    3    4    1  894    1    3   11    8   55]
 [   2    4   15   43    3  784    6    2   26    7]
 [  13    4   11    1    9    9  900    2    9    0]
 [   0   17   15   15   12    2    0  922    1   44]
 [  26   10    9   22    7   21    7    3  863    6]
 [   6   11    6   19   29    5    0   27   14  892]]
{'eval_loss': 0.2785503566265106, 'eval_f1': 0.9140589032462638, 'eval_precision': 0.9145182921052126, 'eval_recall': 0.9141, 'eval_accuracy': 0.9141, 'eval_runtime': 2.2112, 'eval_samples_per_second': 4522.393, 'eval_steps_per_second': 71.002, 'epoch': 7.0}
{'loss': 0.2864, 'grad_norm': 1.9672399759292603, 'learning_rate': 7.685021568435078e-05, 'epoch': 7.46}
{'loss': 0.2665, 'grad_norm': 2.6236050128936768, 'learning_rate': 4.894632455610773e-0

  0%|          | 0/157 [00:00<?, ?it/s]

Confusion Matrix:
[[ 942    0    2    1    0    6   15    0   11    3]
 [   0 1107    7    1    0    4    6    0    9    1]
 [   6    3  894   64    5   18   20    8   11    3]
 [   0    3   35  929    0   16    0   13   10    4]
 [   0    2    4    1  893    2    3   10    9   58]
 [   2    5    7   49    3  787    5    3   24    7]
 [   6    4   10    1    9    9  908    2    9    0]
 [   0   14   16   12   11    4    0  927    1   43]
 [  20    8    9   23    6   25    7    3  866    7]
 [   4    9    6   12   30    5    0   32   15  896]]
{'eval_loss': 0.269232839345932, 'eval_f1': 0.9149048708630928, 'eval_precision': 0.9153808479189037, 'eval_recall': 0.9149, 'eval_accuracy': 0.9149, 'eval_runtime': 2.2408, 'eval_samples_per_second': 4462.758, 'eval_steps_per_second': 70.065, 'epoch': 8.0}
{'loss': 0.2616, 'grad_norm': 2.2810122966766357, 'learning_rate': 2.6789135029152173e-05, 'epoch': 8.53}


  0%|          | 0/157 [00:00<?, ?it/s]

Confusion Matrix:
[[ 942    0    2    1    0    6   16    0   10    3]
 [   0 1106    6    2    0    4    6    0   10    1]
 [   6    3  901   61    5   16   19    8   10    3]
 [   0    3   39  930    0   16    0   12    6    4]
 [   0    2    4    1  907    1    3   10    6   48]
 [   2    4   10   47    3  788    5    3   23    7]
 [   7    3   10    1   10    9  910    1    7    0]
 [   0   14   17   12   13    3    0  928    1   40]
 [  22    8    9   23    7   26   10    3  859    7]
 [   4    9    6   16   35    5    0   31   14  889]]
{'eval_loss': 0.2673957645893097, 'eval_f1': 0.9159655975632608, 'eval_precision': 0.9163724161435439, 'eval_recall': 0.916, 'eval_accuracy': 0.916, 'eval_runtime': 2.0843, 'eval_samples_per_second': 4797.723, 'eval_steps_per_second': 75.324, 'epoch': 9.0}
{'loss': 0.2748, 'grad_norm': 2.7655932903289795, 'learning_rate': 1.1011964332097113e-05, 'epoch': 9.06}
{'loss': 0.264, 'grad_norm': 2.1739659309387207, 'learning_rate': 2.065770110498438e-06,

  0%|          | 0/157 [00:00<?, ?it/s]

Checkpoint destination directory ./results\checkpoint-9380 already exists and is non-empty. Saving will proceed but saved results may be invalid.


Confusion Matrix:
[[ 945    0    2    1    0    5   15    0    9    3]
 [   0 1105    6    2    0    4    6    0   11    1]
 [   6    3  896   61    5   17   22    9   10    3]
 [   0    3   40  925    0   16    0   12   10    4]
 [   0    2    4    1  907    1    3    9    6   49]
 [   2    4   11   47    3  786    5    3   24    7]
 [   7    3   10    1   10    9  910    1    7    0]
 [   0   14   17   13   13    3    0  925    1   42]
 [  24    9    9   22    7   22   10    3  861    7]
 [   5    8    6   12   36    5    0   30   15  892]]
{'eval_loss': 0.2666623294353485, 'eval_f1': 0.9151400967608606, 'eval_precision': 0.9154825760974205, 'eval_recall': 0.9152, 'eval_accuracy': 0.9152, 'eval_runtime': 2.1742, 'eval_samples_per_second': 4599.459, 'eval_steps_per_second': 72.212, 'epoch': 10.0}
{'train_runtime': 194.9569, 'train_samples_per_second': 3077.603, 'train_steps_per_second': 48.113, 'train_loss': 0.45974236998730883, 'epoch': 10.0}


  0%|          | 0/157 [00:00<?, ?it/s]

Confusion Matrix:
[[ 942    0    2    1    0    6   16    0   10    3]
 [   0 1106    6    2    0    4    6    0   10    1]
 [   6    3  901   61    5   16   19    8   10    3]
 [   0    3   39  930    0   16    0   12    6    4]
 [   0    2    4    1  907    1    3   10    6   48]
 [   2    4   10   47    3  788    5    3   23    7]
 [   7    3   10    1   10    9  910    1    7    0]
 [   0   14   17   12   13    3    0  928    1   40]
 [  22    8    9   23    7   26   10    3  859    7]
 [   4    9    6   16   35    5    0   31   14  889]]


{'eval_loss': 0.2673957645893097,
 'eval_f1': 0.9159655975632608,
 'eval_precision': 0.9163724161435439,
 'eval_recall': 0.916,
 'eval_accuracy': 0.916,
 'eval_runtime': 2.2432,
 'eval_samples_per_second': 4457.904,
 'eval_steps_per_second': 69.989,
 'epoch': 10.0}

In [7]:
# Run training; the returned TrainOutput is displayed as the cell output.
# NOTE(review): this cell carries the same execution count as the earlier
# train/evaluate cell but shows different losses and runtimes — presumably it was
# run with a different model/config. Confirm which run is current and drop the stale one.
trainer.train()

  0%|          | 0/9380 [00:00<?, ?it/s]

Could not estimate the number of tokens of the input, floating-point operations will not be computed


{'loss': 1.5768, 'grad_norm': 6.464507102966309, 'learning_rate': 0.0004977113991447017, 'epoch': 0.53}


  0%|          | 0/157 [00:00<?, ?it/s]

Confusion Matrix:
[[ 877    0    6    1   24    7   22    1   41    1]
 [   0 1027   16   13    3   46    6    7   15    2]
 [  50    5  648   57   40   87   65    3   69    8]
 [   2    4  105  701   11   52    2   18   48   67]
 [  43    4   18    2  659    9    4    8  163   72]
 [  33   11   44   54   37  581   14    6   87   25]
 [  18    2   15    2    3    5  846    0   67    0]
 [   0   12    2   12   45   20    0  696    5  236]
 [  55    1   12   10   21   22   10    3  823   17]
 [   8    4    1    9  170    3    1   15   48  750]]
{'eval_loss': 0.7616570591926575, 'eval_f1': 0.7619962587705958, 'eval_precision': 0.7746075940242815, 'eval_recall': 0.7608, 'eval_accuracy': 0.7608, 'eval_runtime': 3.6155, 'eval_samples_per_second': 2765.841, 'eval_steps_per_second': 43.424, 'epoch': 1.0}
{'loss': 0.8529, 'grad_norm': 10.953009605407715, 'learning_rate': 0.0004884857252366847, 'epoch': 1.07}
{'loss': 0.6767, 'grad_norm': 20.91096305847168, 'learning_rate': 0.0004724434323532821

  0%|          | 0/157 [00:00<?, ?it/s]

Confusion Matrix:
[[ 805    0   11    2   28   34   29    5   59    7]
 [   0 1079    4    4    3   11    4   16   14    0]
 [  18   19  750  106   10   76   17   13   18    5]
 [   0    7   29  821    2   52    3   49   17   30]
 [  17    6    9    5  688   50    8   27   57  115]
 [   2   11   21   63   10  734   15   12   19    5]
 [   5    8   18    0   12   26  880    0    9    0]
 [   0   10    5   21   13   19    0  879    1   80]
 [  23    3    9   16   16   78   23    1  771   34]
 [   1    1    0   27   29   26    1   98   16  810]]
{'eval_loss': 0.5544259548187256, 'eval_f1': 0.8223033609959465, 'eval_precision': 0.8283879974033689, 'eval_recall': 0.8217, 'eval_accuracy': 0.8217, 'eval_runtime': 3.6611, 'eval_samples_per_second': 2731.41, 'eval_steps_per_second': 42.883, 'epoch': 2.0}
{'loss': 0.5864, 'grad_norm': 8.617034912109375, 'learning_rate': 0.00045004305610692587, 'epoch': 2.13}
{'loss': 0.492, 'grad_norm': 15.679028511047363, 'learning_rate': 0.00042192486471335583

  0%|          | 0/157 [00:00<?, ?it/s]

Confusion Matrix:
[[ 893    0   25    0    9    3   10    4   31    5]
 [   0 1108   13    2    1    2    3    2    3    1]
 [  11   11  948   18    4   26    7    4    2    1]
 [   0   10  146  780    2   32    1   25    6    8]
 [  22    3   30    8  722   20    8   24   57   88]
 [   8    3   64   33    3  748    9    3   15    6]
 [   4    2   15    2   10    6  909    0   10    0]
 [   0   14   36   15    9    8    0  857    3   86]
 [  34    5   11   21   17   24   20    1  807   34]
 [   5    1    5   21   23   17    0   62    8  867]]
{'eval_loss': 0.4066694974899292, 'eval_f1': 0.863750619304394, 'eval_precision': 0.8684244303991644, 'eval_recall': 0.8639, 'eval_accuracy': 0.8639, 'eval_runtime': 3.6374, 'eval_samples_per_second': 2749.242, 'eval_steps_per_second': 43.163, 'epoch': 3.0}
{'loss': 0.4436, 'grad_norm': 5.626135349273682, 'learning_rate': 0.00038889255825490053, 'epoch': 3.2}
{'loss': 0.3974, 'grad_norm': 5.975671291351318, 'learning_rate': 0.00035189029658340025,

  0%|          | 0/157 [00:00<?, ?it/s]

Confusion Matrix:
[[ 936    0    4    0   11    1    4    2   18    4]
 [   0 1120    5    2    1    0    5    0    2    0]
 [  16   14  922   54    3    4   13    3    3    0]
 [   1   31   32  893    3    2   12   30    3    3]
 [  18    5   12   11  853    7   21    8   13   34]
 [  15   15   79   79    7  641   29    4   14    9]
 [   8    3    8    1   14    1  915    0    8    0]
 [   0   11   16   39   19    0    0  863    4   76]
 [  45   15   10   27   21    8   34    5  796   13]
 [  10    8    2   14   77    7    1   29   13  848]]
{'eval_loss': 0.3845398724079132, 'eval_f1': 0.8775848509542602, 'eval_precision': 0.8816738631728249, 'eval_recall': 0.8787, 'eval_accuracy': 0.8787, 'eval_runtime': 3.6513, 'eval_samples_per_second': 2738.768, 'eval_steps_per_second': 42.999, 'epoch': 4.0}
{'loss': 0.3618, 'grad_norm': 5.224878787994385, 'learning_rate': 0.00031197571247243834, 'epoch': 4.26}
{'loss': 0.325, 'grad_norm': 9.919211387634277, 'learning_rate': 0.00027028968138185784

  0%|          | 0/157 [00:00<?, ?it/s]

Confusion Matrix:
[[ 932    0    5    0   16    1   16    2    7    1]
 [   0 1114    5    2    0    0    3    0   11    0]
 [   8   11  954   23    5    4    8    8    7    4]
 [   0   18   37  897    0   18    3   17    9   11]
 [  10    8   12    1  829    5   20    5   37   55]
 [   7    4   55   51   11  695   32    3   17   17]
 [   5    2    8    1    7    2  926    1    6    0]
 [   0   11    9   11   17    2    0  902    5   71]
 [  22    4   11   17    3   14   37    0  858    8]
 [   4    5    2   14   43    5    1   22   28  885]]
{'eval_loss': 0.3215293288230896, 'eval_f1': 0.8986272841009766, 'eval_precision': 0.9002846041239704, 'eval_recall': 0.8992, 'eval_accuracy': 0.8992, 'eval_runtime': 3.6957, 'eval_samples_per_second': 2705.867, 'eval_steps_per_second': 42.482, 'epoch': 5.0}
{'loss': 0.3149, 'grad_norm': 14.113794326782227, 'learning_rate': 0.00022802371190303695, 'epoch': 5.33}
{'loss': 0.301, 'grad_norm': 13.288413047790527, 'learning_rate': 0.000186385888961295

  0%|          | 0/157 [00:00<?, ?it/s]

Confusion Matrix:
[[ 957    0    1    0    2    4    7    1    7    1]
 [   0 1122    3    1    1    0    4    1    3    0]
 [  16   27  861   57    5   30    9   17    6    4]
 [   2   13   16  935    1   12    0   24    3    4]
 [  10    5    3    3  827   24   14    8   10   78]
 [   4    4    3   40    2  812    7    2   14    4]
 [   6    3    3    4    7    7  926    0    2    0]
 [   1    8   10   12   20    5    0  888    4   80]
 [  31    5    2   15    3   35   37    6  821   19]
 [  14    4    1    7   25   12    1   13   13  919]]
{'eval_loss': 0.2990776300430298, 'eval_f1': 0.9063381509815707, 'eval_precision': 0.9087926569387182, 'eval_recall': 0.9068, 'eval_accuracy': 0.9068, 'eval_runtime': 3.5378, 'eval_samples_per_second': 2826.646, 'eval_steps_per_second': 44.378, 'epoch': 6.0}
{'loss': 0.2858, 'grad_norm': 4.608861923217773, 'learning_rate': 0.0001465663432182349, 'epoch': 6.4}
{'loss': 0.2742, 'grad_norm': 9.254112243652344, 'learning_rate': 0.00010970323365940444,

  0%|          | 0/157 [00:00<?, ?it/s]

Confusion Matrix:
[[ 925    0    7    1   13    4    1    1   23    5]
 [   0 1112    3    5    0    3    4    4    3    1]
 [  10    8  886   80   11   15    5   14    2    1]
 [   0    8   12  945    1   21    0   17    3    3]
 [   2    3    6    3  815    6    7   21    7  112]
 [   2    4   14   50    8  775    2    3   19   15]
 [   5    2    2    4   23   11  899    0   11    1]
 [   0    9    8   20    6    2    0  914    1   68]
 [  15    5    2   35    9   16    8    3  855   26]
 [   5    2    1   18    9    5    0   20    5  944]]
{'eval_loss': 0.27679070830345154, 'eval_f1': 0.907416436375048, 'eval_precision': 0.9107548690543191, 'eval_recall': 0.907, 'eval_accuracy': 0.907, 'eval_runtime': 3.6525, 'eval_samples_per_second': 2737.86, 'eval_steps_per_second': 42.984, 'epoch': 7.0}
{'loss': 0.2604, 'grad_norm': 19.479860305786133, 'learning_rate': 7.685021568435078e-05, 'epoch': 7.46}
{'loss': 0.2514, 'grad_norm': 12.069538116455078, 'learning_rate': 4.894632455610773e-05, 

  0%|          | 0/157 [00:00<?, ?it/s]

Confusion Matrix:
[[ 943    0    6    0    5    5    4    1   13    3]
 [   0 1112    3    2    0    0    3    0   14    1]
 [  10   10  908   39    8   26    3   17   10    1]
 [   3    8   16  920    0   24    1   17   12    9]
 [   4    2    2    2  826   11    6   32   23   74]
 [   2    2    5   30    5  804    9    2   29    4]
 [   3    4    3    3   14   11  909    0   11    0]
 [   0    8    5   15    5    3    0  951    4   37]
 [  17    3    1    8    3   15    6    2  907   12]
 [   7    1    0   11   11   10    0   40   21  908]]
{'eval_loss': 0.26105132699012756, 'eval_f1': 0.9188035293363653, 'eval_precision': 0.9200859889091951, 'eval_recall': 0.9188, 'eval_accuracy': 0.9188, 'eval_runtime': 3.6244, 'eval_samples_per_second': 2759.085, 'eval_steps_per_second': 43.318, 'epoch': 8.0}
{'loss': 0.2416, 'grad_norm': 8.064125061035156, 'learning_rate': 2.6789135029152173e-05, 'epoch': 8.53}


  0%|          | 0/157 [00:00<?, ?it/s]

Confusion Matrix:
[[ 925    0    4    0    9    4    8    2   21    7]
 [   0 1116    2    1    1    4    5    0    6    0]
 [   8   11  944   23    6   19    5   10    5    1]
 [   1    9   16  943    0   24    0   10    5    2]
 [   3    3    7    3  875   13   11    6   19   42]
 [   5    2    8   27    2  824    5    1   12    6]
 [   4    2    2    4   10    9  920    0    7    0]
 [   0   15   14   13   21    2    0  900    6   57]
 [  12    4    4   17    2   22    9    0  900    4]
 [   6    4    1   14   30   12    0   12   16  914]]
{'eval_loss': 0.242593914270401, 'eval_f1': 0.9260794668701173, 'eval_precision': 0.9267383853604548, 'eval_recall': 0.9261, 'eval_accuracy': 0.9261, 'eval_runtime': 3.6315, 'eval_samples_per_second': 2753.698, 'eval_steps_per_second': 43.233, 'epoch': 9.0}
{'loss': 0.2442, 'grad_norm': 4.718122959136963, 'learning_rate': 1.1011964332097113e-05, 'epoch': 9.06}
{'loss': 0.2297, 'grad_norm': 4.965157985687256, 'learning_rate': 2.065770110498438e-06,

  0%|          | 0/157 [00:00<?, ?it/s]

Confusion Matrix:
[[ 944    0    5    0    6    3    7    2   12    1]
 [   0 1119    3    2    0    1    4    0    6    0]
 [   8    9  968   20    2   11    4    8    2    0]
 [   2    8   21  949    0   12    1   10    4    3]
 [   2    5   11    2  897   10   12    7   10   26]
 [   3    3   15   35    2  808    8    1   12    5]
 [   5    2    6    2    8    6  922    0    7    0]
 [   0   15   16    9   18    2    0  919    5   44]
 [  23    3    5   22    6   18    9    0  882    6]
 [   8    3    4   15   57    9    0   22   14  877]]
{'eval_loss': 0.23178459703922272, 'eval_f1': 0.9282962114422455, 'eval_precision': 0.9285663043775354, 'eval_recall': 0.9285, 'eval_accuracy': 0.9285, 'eval_runtime': 3.5961, 'eval_samples_per_second': 2780.803, 'eval_steps_per_second': 43.659, 'epoch': 10.0}
{'train_runtime': 340.4996, 'train_samples_per_second': 1762.117, 'train_steps_per_second': 27.548, 'train_loss': 0.44155973503584545, 'epoch': 10.0}


TrainOutput(global_step=9380, training_loss=0.44155973503584545, metrics={'train_runtime': 340.4996, 'train_samples_per_second': 1762.117, 'train_steps_per_second': 27.548, 'train_loss': 0.44155973503584545, 'epoch': 10.0})

In [8]:
# Evaluate the trainer's current model; the returned metrics dict is displayed as the cell output.
trainer.evaluate()

  0%|          | 0/157 [00:00<?, ?it/s]

Confusion Matrix:
[[ 944    0    5    0    6    3    7    2   12    1]
 [   0 1119    3    2    0    1    4    0    6    0]
 [   8    9  968   20    2   11    4    8    2    0]
 [   2    8   21  949    0   12    1   10    4    3]
 [   2    5   11    2  897   10   12    7   10   26]
 [   3    3   15   35    2  808    8    1   12    5]
 [   5    2    6    2    8    6  922    0    7    0]
 [   0   15   16    9   18    2    0  919    5   44]
 [  23    3    5   22    6   18    9    0  882    6]
 [   8    3    4   15   57    9    0   22   14  877]]


{'eval_loss': 0.23178459703922272,
 'eval_f1': 0.9282962114422455,
 'eval_precision': 0.9285663043775354,
 'eval_recall': 0.9285,
 'eval_accuracy': 0.9285,
 'eval_runtime': 3.4166,
 'eval_samples_per_second': 2926.901,
 'eval_steps_per_second': 45.952,
 'epoch': 10.0}

In [6]:
import torch
from torch import nn
from torchvision import transforms, datasets
from bitlinear import BitLinear as Linear

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Network dimensions.
input_size = 784
hidden_size = 100
num_classes = 10

# Training hyper-parameters.
# NOTE(review): the optimizer created later in this cell passes lr=0.00001
# literally and does not read `learning_rate` — confirm which value is intended.
learning_rate = 0.001
num_epochs = 5


# Mini-batch iterators over the MNIST datasets created in an earlier cell;
# only the training split is shuffled.
train_loader = torch.utils.data.DataLoader(
    dataset=train_dataset,
    batch_size=64,
    shuffle=True,
)
test_loader = torch.utils.data.DataLoader(
    dataset=test_dataset,
    batch_size=64,
    shuffle=False,
)


class Net(nn.Module):
    """Two-layer classifier built from the cell's `Linear` layer type.

    `Linear` is the name this cell imports from `bitlinear` (BitLinear).

    Args:
        input_size: number of features in each input row.
        hidden_size: width of the hidden layer.
        num_classes: number of output classes.
    """

    def __init__(self, input_size, hidden_size, num_classes):
        super().__init__()
        self.fc1 = Linear(input_size, hidden_size)
        self.fc2 = Linear(hidden_size, num_classes)

    def forward(self, x):
        """Return log-probabilities over the classes (log-softmax along dim 1)."""
        hidden = torch.relu(self.fc1(x))
        scores = self.fc2(hidden)
        return torch.log_softmax(scores, dim=1)


# Create the model and optimizer
model = Net(input_size, hidden_size, num_classes)
model = model.to(device)

# NOTE(review): the step size is passed literally; the `learning_rate` value
# defined with the other hyper-parameters in this cell is not used here.
# The literal is kept so that training behaviour is unchanged — confirm intent.
optimizer = torch.optim.Adam(model.parameters(), lr=0.00001)
# NLLLoss expects log-probabilities, which Net.forward produces via log_softmax.
criterion = nn.NLLLoss()

# Train the model
for epoch in range(num_epochs):
    for i, (images, labels) in enumerate(train_loader):
        # Flatten each image to one row of `input_size` features and move the
        # batch AND its labels to the model's device, so the forward pass, the
        # loss and the backward pass all stay on one device (no per-step copy
        # of the outputs back to the CPU).
        images = images.view(-1, input_size).to(device)
        labels = labels.to(device)

        # Forward pass
        outputs = model(images)
        loss = criterion(outputs, labels)

        # Backward pass and optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Print training information every 100 steps
        if (i + 1) % 100 == 0:
            print(f'Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{len(train_loader)}], Loss: {loss.item():.4f}')

Epoch [1/5], Step [100/938], Loss: 2.1903
Epoch [1/5], Step [200/938], Loss: 2.0577
Epoch [1/5], Step [300/938], Loss: 1.9495
Epoch [1/5], Step [400/938], Loss: 1.8025
Epoch [1/5], Step [500/938], Loss: 1.6161
Epoch [1/5], Step [600/938], Loss: 1.5318
Epoch [1/5], Step [700/938], Loss: 1.4546
Epoch [1/5], Step [800/938], Loss: 1.3805
Epoch [1/5], Step [900/938], Loss: 1.2134
Epoch [2/5], Step [100/938], Loss: 1.2599
Epoch [2/5], Step [200/938], Loss: 1.0751
Epoch [2/5], Step [300/938], Loss: 1.2669
Epoch [2/5], Step [400/938], Loss: 1.0585
Epoch [2/5], Step [500/938], Loss: 1.0870
Epoch [2/5], Step [600/938], Loss: 0.9355
Epoch [2/5], Step [700/938], Loss: 0.9594
Epoch [2/5], Step [800/938], Loss: 0.9366
Epoch [2/5], Step [900/938], Loss: 0.9701
Epoch [3/5], Step [100/938], Loss: 0.9047
Epoch [3/5], Step [200/938], Loss: 0.7651
Epoch [3/5], Step [300/938], Loss: 0.8008
Epoch [3/5], Step [400/938], Loss: 0.7867
Epoch [3/5], Step [500/938], Loss: 0.8615
Epoch [3/5], Step [600/938], Loss:

RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cpu and cuda:0! (when checking argument for argument weight in method wrapper_CUDA__native_layer_norm)

In [8]:
# Test the model
# Gradients are not needed for evaluation, so autograd tracking is disabled.
with torch.no_grad():
    correct = 0
    total = 0
    for images, labels in test_loader:
        # Flatten the images and keep inputs, predictions and labels on the
        # same device, so the comparison needs no copy of the outputs to the CPU.
        images = images.view(-1, 28 * 28).to(device)
        labels = labels.to(device)

        # Index of the highest log-probability per row is the predicted class.
        predicted = model(images).argmax(dim=1)

        total += labels.size(0)
        correct += (predicted == labels).sum().item()

    # Report the number of images actually evaluated rather than a fixed count.
    print(f'Accuracy of the network on the {total} test images: {100 * correct / total:.2f}%')

Accuracy of the network on the 10000 test images: 89.72%
