In [1]:
import torch
from config import BitformerConfig
from model_zoo import VisionBitformerForImageClassification, VisionBitformerForSemanticSegmentation
from metrics import compute_metrics_single_label_classification
from trainer import get_trainer
from data_zoo import get_vision_dataset, vision_collator
from utils import get_yaml

In [2]:
# Experiment configuration. Later cells read its 'model_config',
# 'general_config' and 'training_args' sections.
yaml_path = './yamls/vision/small.yaml'
args = get_yaml(
    yaml_path
)

In [10]:
# Disable the `bitnet` flag of the model config (presumably switches off
# BitNet-style layers in BitformerConfig -- confirm). NOTE(review): this only
# takes effect if the model-building cell below is (re-)run after this one;
# in the saved run this cell executed after the model was already built.
args['model_config'].update(bitnet=False)

In [3]:
# Build the classifier from the YAML model section plus the label count,
# then display its size in millions of parameters.
num_labels = args['general_config']['num_labels']
cfg = BitformerConfig(**args['model_config'], num_labels=num_labels)
model = VisionBitformerForImageClassification(config=cfg)
model.num_parameters() / 1e6

1.131476

In [4]:
# Dataset identifier understood by data_zoo.get_vision_dataset, which yields
# a (train, test) pair.
data_path = 'mnist'
dataset_splits = get_vision_dataset(data_path)
train_dataset, test_dataset = dataset_splits

In [5]:
trainer = get_trainer(
    model=model,
    train_dataset=train_dataset,
    # The test split doubles as the validation set during training.
    valid_dataset=test_dataset,
    compute_metrics=compute_metrics_single_label_classification,
    data_collator=vision_collator,
    **args['training_args'],
)

In [6]:
# NOTE(review): this cell is duplicated by a later training cell and was
# interrupted (KeyboardInterrupt) in the saved run. The log format looks like a
# Hugging Face Trainer (confirm in trainer.get_trainer); if so, a second
# train() on the same object presumably continues from the weights this call
# leaves behind rather than from a freshly initialised model.
train_result = trainer.train()
eval_metrics = trainer.evaluate()
eval_metrics

  0%|          | 0/9380 [00:00<?, ?it/s]

Could not estimate the number of tokens of the input, floating-point operations will not be computed


{'loss': 1.0461, 'grad_norm': 6.677570343017578, 'learning_rate': 0.0004900072000534774, 'epoch': 1.0}


  0%|          | 0/157 [00:00<?, ?it/s]

{'eval_loss': 0.5116379261016846, 'eval_f1': 0.8410506357325556, 'eval_precision': 0.8495047742662352, 'eval_recall': 0.8411, 'eval_accuracy': 0.8411, 'eval_runtime': 2.1765, 'eval_samples_per_second': 4594.562, 'eval_steps_per_second': 72.135, 'epoch': 1.0}


KeyboardInterrupt: 

In [13]:
# Train, then run a final evaluation (presumably on the valid_dataset passed
# to get_trainer). In the saved output the final metrics equal the epoch-9
# evaluation rather than epoch 10, which suggests the best checkpoint is
# reloaded at the end -- check load_best_model_at_end in training_args.
trainer.train()
final_metrics = trainer.evaluate()
final_metrics

  0%|          | 0/9380 [00:00<?, ?it/s]

Could not estimate the number of tokens of the input, floating-point operations will not be computed


{'loss': 0.8597, 'grad_norm': 1.9200587272644043, 'learning_rate': 0.0004900072000534774, 'epoch': 1.0}


  0%|          | 0/157 [00:00<?, ?it/s]

{'eval_loss': 0.3140251636505127, 'eval_f1': 0.9135553046837643, 'eval_precision': 0.9142339999560493, 'eval_recall': 0.9139, 'eval_accuracy': 0.9139, 'eval_runtime': 0.7525, 'eval_samples_per_second': 13288.821, 'eval_steps_per_second': 208.634, 'epoch': 1.0}
{'loss': 0.2633, 'grad_norm': 1.5263727903366089, 'learning_rate': 0.00045615929491282483, 'epoch': 2.0}


  0%|          | 0/157 [00:00<?, ?it/s]

{'eval_loss': 0.21613578498363495, 'eval_f1': 0.9358811972883226, 'eval_precision': 0.9367202103839279, 'eval_recall': 0.9359, 'eval_accuracy': 0.9359, 'eval_runtime': 0.7631, 'eval_samples_per_second': 13104.694, 'eval_steps_per_second': 205.744, 'epoch': 2.0}
{'loss': 0.1922, 'grad_norm': 1.6930099725723267, 'learning_rate': 0.00040169749790445905, 'epoch': 3.0}


  0%|          | 0/157 [00:00<?, ?it/s]

{'eval_loss': 0.17053326964378357, 'eval_f1': 0.9475878357234756, 'eval_precision': 0.9478322857344401, 'eval_recall': 0.9476, 'eval_accuracy': 0.9476, 'eval_runtime': 0.7619, 'eval_samples_per_second': 13125.625, 'eval_steps_per_second': 206.072, 'epoch': 3.0}
{'loss': 0.1612, 'grad_norm': 0.7695346474647522, 'learning_rate': 0.0003320674504433184, 'epoch': 4.0}


  0%|          | 0/157 [00:00<?, ?it/s]

{'eval_loss': 0.15335500240325928, 'eval_f1': 0.9524743689204, 'eval_precision': 0.95299495516058, 'eval_recall': 0.9525, 'eval_accuracy': 0.9525, 'eval_runtime': 0.7715, 'eval_samples_per_second': 12962.49, 'eval_steps_per_second': 203.511, 'epoch': 4.0}
{'loss': 0.1419, 'grad_norm': 1.5058592557907104, 'learning_rate': 0.000254231469070671, 'epoch': 5.0}


  0%|          | 0/157 [00:00<?, ?it/s]

{'eval_loss': 0.1420908421278, 'eval_f1': 0.9574239546636701, 'eval_precision': 0.9578308720308931, 'eval_recall': 0.9574, 'eval_accuracy': 0.9574, 'eval_runtime': 0.7818, 'eval_samples_per_second': 12790.581, 'eval_steps_per_second': 200.812, 'epoch': 5.0}
{'loss': 0.1283, 'grad_norm': 0.12763918936252594, 'learning_rate': 0.00017597238261366985, 'epoch': 6.0}


  0%|          | 0/157 [00:00<?, ?it/s]

{'eval_loss': 0.129969522356987, 'eval_f1': 0.961242367614776, 'eval_precision': 0.9614252933581583, 'eval_recall': 0.9613, 'eval_accuracy': 0.9613, 'eval_runtime': 0.7893, 'eval_samples_per_second': 12669.037, 'eval_steps_per_second': 198.904, 'epoch': 6.0}
{'loss': 0.1174, 'grad_norm': 2.4719347953796387, 'learning_rate': 0.00010511532622604558, 'epoch': 7.0}


  0%|          | 0/157 [00:00<?, ?it/s]

{'eval_loss': 0.12279283255338669, 'eval_f1': 0.9640682937797396, 'eval_precision': 0.9640990308197468, 'eval_recall': 0.9641, 'eval_accuracy': 0.9641, 'eval_runtime': 0.7619, 'eval_samples_per_second': 13125.05, 'eval_steps_per_second': 206.063, 'epoch': 7.0}
{'loss': 0.1106, 'grad_norm': 1.9979044198989868, 'learning_rate': 4.8745305214285054e-05, 'epoch': 8.0}


  0%|          | 0/157 [00:00<?, ?it/s]

{'eval_loss': 0.12025178223848343, 'eval_f1': 0.9634540007925865, 'eval_precision': 0.9635174928552249, 'eval_recall': 0.9635, 'eval_accuracy': 0.9635, 'eval_runtime': 0.7837, 'eval_samples_per_second': 12760.486, 'eval_steps_per_second': 200.34, 'epoch': 8.0}
{'loss': 0.1056, 'grad_norm': 1.1408836841583252, 'learning_rate': 1.2498764533288465e-05, 'epoch': 9.0}


  0%|          | 0/157 [00:00<?, ?it/s]

{'eval_loss': 0.11933530122041702, 'eval_f1': 0.9643829509011197, 'eval_precision': 0.9644782761401038, 'eval_recall': 0.9644, 'eval_accuracy': 0.9644, 'eval_runtime': 0.7708, 'eval_samples_per_second': 12973.777, 'eval_steps_per_second': 203.688, 'epoch': 9.0}
{'loss': 0.1031, 'grad_norm': 2.4104082584381104, 'learning_rate': 0.0, 'epoch': 10.0}


  0%|          | 0/157 [00:00<?, ?it/s]

Checkpoint destination directory ./results\checkpoint-9380 already exists and is non-empty. Saving will proceed but saved results may be invalid.


{'eval_loss': 0.11868536472320557, 'eval_f1': 0.9640774850942858, 'eval_precision': 0.9641373508731603, 'eval_recall': 0.9641, 'eval_accuracy': 0.9641, 'eval_runtime': 0.797, 'eval_samples_per_second': 12547.586, 'eval_steps_per_second': 196.997, 'epoch': 10.0}
{'train_runtime': 112.5335, 'train_samples_per_second': 5331.745, 'train_steps_per_second': 83.353, 'train_loss': 0.2183222813392753, 'epoch': 10.0}


  0%|          | 0/157 [00:00<?, ?it/s]

{'eval_loss': 0.11933530122041702,
 'eval_f1': 0.9643829509011197,
 'eval_precision': 0.9644782761401038,
 'eval_recall': 0.9644,
 'eval_accuracy': 0.9644,
 'eval_runtime': 0.7916,
 'eval_samples_per_second': 12633.389,
 'eval_steps_per_second': 198.344,
 'epoch': 10.0}

In [6]:
import torch
import torch.nn as nn

x = torch.rand(64, 24, 14, 14)

# Round-trip the tensor through each pooling op and a matching upsample,
# printing the shape after every step. Because 14 is not divisible by 4,
# pooling by 4 floors to 3x3 and upsampling returns 12x12, not 14x14 --
# fixed-factor pool/upsample pairs do not restore arbitrary spatial sizes.
round_trips = [
    (nn.AvgPool2d, 2),
    (nn.AvgPool2d, 4),
    (nn.MaxPool2d, 2),
    (nn.MaxPool2d, 4),
]
for pool_cls, factor in round_trips:
    x = pool_cls(factor)(x)
    print(x.shape)
    x = nn.Upsample(scale_factor=factor)(x)
    print(x.shape)

torch.Size([64, 24, 7, 7])
torch.Size([64, 24, 14, 14])
torch.Size([64, 24, 3, 3])
torch.Size([64, 24, 12, 12])
torch.Size([64, 24, 6, 6])
torch.Size([64, 24, 12, 12])
torch.Size([64, 24, 3, 3])
torch.Size([64, 24, 12, 12])


In [16]:
x = torch.rand(64, 24, 14, 14)

class HANCLayer(nn.Module):
    """
    Implements Hierarchical Aggregation of Neighborhood Context (HANC) operation.

    The input is augmented with average- and max-pooled copies of itself at
    strides 2, 4, ..., 2**(k-1). Each pooled map is resized back to the input's
    (H, W) with nn.Upsample (default nearest-neighbour mode). The resulting
    (2k - 1) maps are stacked per channel, fused by a 1x1 convolution, then
    passed through BatchNorm and LeakyReLU.
    """

    def __init__(self, in_chnl, out_chnl, k):
        """
        Initialization

        Args:
            in_chnl (int): number of input channels
            out_chnl (int): number of output channels
            k (int): number of hierarchy levels (k >= 1). Level 1 is the input
                itself; every further level adds one avg- and one max-pooled map.

        Raises:
            ValueError: if k < 1
        """
        super().__init__()

        if k < 1:
            raise ValueError(f"k must be >= 1, got {k}")
        self.k = k

        # Each of the (2k - 1) context maps contributes in_chnl channels.
        self.cnv = nn.Conv2d((2 * k - 1) * in_chnl, out_chnl, kernel_size=(1, 1))
        self.act = nn.LeakyReLU()
        self.bn = nn.BatchNorm2d(out_chnl)

    def forward(self, inp):
        """
        Args:
            inp (Tensor): 4-D input (batch, channels, H, W)

        Returns:
            Tensor of shape (batch, out_chnl, H, W)
        """
        batch_size, num_channels, H, W = inp.size()

        # Pool-target sizes for levels 2..k: (H // 2, W // 2), (H // 4, W // 4), ...
        # Clamped to 1 so deep levels on small inputs degrade to global pooling
        # instead of requesting an empty (0-sized) pooled map.
        sizes = [
            (max(1, H // 2 ** level), max(1, W // 2 ** level))
            for level in range(1, self.k)
        ]
        upsample = nn.Upsample(size=(H, W))

        # Order: input, avg-pooled maps (finest -> coarsest), max-pooled maps.
        maps = [inp]
        maps += [upsample(nn.AdaptiveAvgPool2d(s)(inp)) for s in sizes]
        maps += [upsample(nn.AdaptiveMaxPool2d(s)(inp)) for s in sizes]

        # Concatenating along the height axis and viewing the contiguous result
        # as (batch, C * (2k - 1), H, W) puts the j-th map of input channel c at
        # output channel c * (2k - 1) + j. This is a fixed per-channel grouping,
        # which the learned 1x1 conv below can absorb.
        x = torch.concat(maps, dim=2)
        x = x.view(batch_size, num_channels * (2 * self.k - 1), H, W)
        x = self.act(self.bn(self.cnv(x)))

        return x

torch.Size([64, 24, 42, 14])