In [1]:
from dataclasses import dataclass
from datetime import datetime

import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import numpy as np 
import torch.nn.functional as F
from torch.utils.tensorboard import SummaryWriter
from torchvision import datasets, transforms
from tqdm.auto import tqdm
from torchfuzzy import FuzzyLayer, DefuzzyLinearLayer

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
batch_size = 256
learning_rate = 1e-4
weight_decay = 1e-2
labels_count = 10
num_epochs = 20
fuzzy_dim = 2
fuzzy_rules_count = 10

prefix = "mamdani_mnist"
writer = SummaryWriter(f'runs/mnist/{prefix}_{datetime.now().strftime("%Y%m%d-%H%M%S")}')
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

device(type='cuda')

In [3]:
# inp = torch.rand((10,3))
# df = DefuzzyLinearLayer.from_dimensions(3, 2)
# df.forward(inp),2)

In [4]:
transform = transforms.Compose([
    transforms.ToTensor(), 
    transforms.Lambda(lambda x: x.view(-1, 28, 28) - 0.5),
])

In [5]:
def get_target(target_label):
    """
    Возвращает вектор целевого значения

    Args:
        target_label (int): Метка класса
    
    Returns:
        tensor (1, 10)
    """
    t = F.one_hot(torch.LongTensor([target_label]), labels_count)
    return t.to(device)

In [6]:
# загружаем обучающую выборку
train_data = datasets.MNIST(
    '~/.pytorch/MNIST_data/', 
    download=True, 
    train=True, 
    transform = transform,
    target_transform = transforms.Lambda(lambda x: get_target(x))
)

In [7]:
# загружаем тестовую выборку
test_data = datasets.MNIST(
    '~/.pytorch/MNIST_data/', 
    download=True, 
    train=False, 
    transform=transform, 
    target_transform = transforms.Lambda(lambda x: get_target(x))
)
len(test_data)

10000

In [8]:
# Создаем итераторы датасетов
train_loader = torch.utils.data.DataLoader(
    train_data, 
    batch_size=batch_size, 
    shuffle=True,
    
)
test_loader = torch.utils.data.DataLoader(
    test_data, 
    batch_size=batch_size, 
    shuffle=False,
)

In [9]:
class Encoder(nn.Module):
    """
    Компонент энкодера
    
    Args:
        fuzzy_dim (int): Размер латентного вектора.
    """
    
    def __init__(self, fuzzy_dim):
        super(Encoder, self).__init__()
                
        self.encoder = nn.Sequential(
            nn.Conv2d(1, 8, kernel_size=5),
            nn.SiLU(),  
            nn.Conv2d(8, 16, kernel_size=5),
            nn.SiLU(),  
            nn.Conv2d(16, 32, kernel_size=5),
            nn.SiLU(),  
            nn.Conv2d(32, 64, kernel_size=5),
            nn.BatchNorm2d(64),
            nn.SiLU(),  
            nn.Flatten(),
            nn.Linear(9216, 625),
            nn.BatchNorm1d(625),
            nn.SiLU(),  
            nn.Linear(625, fuzzy_dim),
        )
         
    def forward(self, x):
        """
        Выход энкодера
        
        Args:
            x (torch.Tensor): Входной вектор.
        
        Returns:
            encoded input
        """

        ex = self.encoder(x)
        
        return ex

In [10]:
class MamdaniFIS(nn.Module):
    """
    MamdaniFIS
    
    Args:
        fuzzy_dim (int): Размер латентного вектора.
        labels_count (int): Количество выходов классификатора
    """
    def __init__(self, fuzzy_dim, fuzzy_rules_count, labels_count):
        super(MamdaniFIS, self).__init__()

        self.encoder = Encoder(fuzzy_dim)        
        
        self.fuzzy = nn.Sequential(
            FuzzyLayer.from_dimensions(fuzzy_dim, fuzzy_rules_count, trainable=True),
            DefuzzyLinearLayer.from_dimensions(fuzzy_rules_count, labels_count)
        )
        
    def forward(self, x):
        """
        
        
        Args:
            x (torch.Tensor): Входной вектор.
        
        Returns:
            labels
        """

        ex = self.encoder(x)
        labels = self.fuzzy(ex)

        return labels

In [11]:
def compute_loss(target_labels, predicted_labels):
    
    #print(torch.squeeze(target_labels,1))
    #print(predicted_labels)
    ceLoss = nn.CrossEntropyLoss()
    
    loss_fuzzy = ceLoss.forward(predicted_labels, torch.squeeze(target_labels,1).float())

    return loss_fuzzy
    

In [12]:
model = MamdaniFIS(fuzzy_dim, fuzzy_rules_count, labels_count).to(device)

num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f'Number of parameters: {num_params:,}')

model

Number of parameters: 5,830,935


MamdaniFIS(
  (encoder): Encoder(
    (encoder): Sequential(
      (0): Conv2d(1, 8, kernel_size=(5, 5), stride=(1, 1))
      (1): SiLU()
      (2): Conv2d(8, 16, kernel_size=(5, 5), stride=(1, 1))
      (3): SiLU()
      (4): Conv2d(16, 32, kernel_size=(5, 5), stride=(1, 1))
      (5): SiLU()
      (6): Conv2d(32, 64, kernel_size=(5, 5), stride=(1, 1))
      (7): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (8): SiLU()
      (9): Flatten(start_dim=1, end_dim=-1)
      (10): Linear(in_features=9216, out_features=625, bias=True)
      (11): BatchNorm1d(625, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (12): SiLU()
      (13): Linear(in_features=625, out_features=2, bias=True)
    )
  )
  (fuzzy): Sequential(
    (0): FuzzyLayer()
    (1): DefuzzyLinearLayer()
  )
)

In [13]:
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay)

In [14]:
def train(model, dataloader, optimizer, prev_updates, writer=None):
    model.train()  
    
    for batch_idx, (data, target) in enumerate(tqdm(dataloader)):
        n_upd = prev_updates + batch_idx
        
        data = data.to(device)
        
        optimizer.zero_grad()  
        
        labels = model.forward(data)  
        
        loss = compute_loss(target, labels)
        
        loss.backward()
        
        if n_upd % 100 == 0:
            total_norm = 0.0
            for p in model.parameters():
                if p.grad is not None:
                    param_norm = p.grad.data.norm(2)
                    total_norm += param_norm.item() ** 2
            total_norm = total_norm ** (1. / 2)
        
            print(f'Step {n_upd:,} (N samples: {n_upd*batch_size:,}), Loss: {loss.item():.4f} Grad: {total_norm:.4f}')

            if writer is not None:
                global_step = n_upd
                writer.add_scalar('Loss/Train', loss.item(), global_step)
                writer.add_scalar('GradNorm/Train', total_norm, global_step)
            
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)    
        
        optimizer.step()  
        
    return prev_updates + len(dataloader)

In [15]:
def test(model, dataloader, cur_step, writer=None):
    model.eval() 
    test_loss = 0
    test_accuracy = 0
    
    with torch.no_grad():
        for data, target in tqdm(dataloader, desc='Testing'):
            data = data.to(device)
            
            labels = model.forward(data)  
            
            loss = compute_loss(target, labels)
            
            test_loss += loss.item()
            
            pred_target = np.argmax(labels.cpu().numpy(), axis=1)
            target_labels =  np.argmax(torch.squeeze(target,1).cpu().numpy(), axis=1)
            test_accuracy += np.sum(target_labels==pred_target) / len(pred_target)

    test_loss /= len(dataloader)
    test_accuracy /= len(dataloader)

    print(f'====> Test set loss: {test_loss:.4f} (Accuracy {test_accuracy:.4f})')
    
    if writer is not None:
        writer.add_scalar('Loss/Test', test_loss, global_step=cur_step)
        writer.add_scalar('Fuzzy/Test/Accuracy', test_accuracy, global_step=cur_step)

In [16]:
prev_updates = 0
for epoch in range(num_epochs):
    print(f'Epoch {epoch+1}/{num_epochs}')
    prev_updates = train(model, train_loader, optimizer, prev_updates, writer=writer)
    test(model, test_loader, prev_updates, writer=writer)

Epoch 1/20


  0%|          | 1/235 [00:01<04:25,  1.14s/it]

Step 0 (N samples: 0), Loss: 2.3049 Grad: 0.5766


 44%|████▍     | 104/235 [00:42<00:06, 20.38it/s]

Step 100 (N samples: 25,600), Loss: 2.1859 Grad: 0.2704


 86%|████████▋ | 203/235 [00:47<00:01, 21.72it/s]

Step 200 (N samples: 51,200), Loss: 2.1533 Grad: 0.3713


100%|██████████| 235/235 [00:48<00:00,  4.85it/s]
Testing: 100%|██████████| 40/40 [00:01<00:00, 26.27it/s]


====> Test set loss: 2.1466 (Accuracy 0.3293)
Epoch 2/20


 29%|██▉       | 69/235 [00:03<00:07, 22.20it/s]

Step 300 (N samples: 76,800), Loss: 2.1230 Grad: 0.4119


 71%|███████▏  | 168/235 [00:07<00:03, 21.82it/s]

Step 400 (N samples: 102,400), Loss: 2.0593 Grad: 0.4149


100%|██████████| 235/235 [00:10<00:00, 21.92it/s]
Testing: 100%|██████████| 40/40 [00:01<00:00, 26.10it/s]


====> Test set loss: 2.0414 (Accuracy 0.3630)
Epoch 3/20


 14%|█▍        | 33/235 [00:01<00:09, 20.81it/s]

Step 500 (N samples: 128,000), Loss: 2.0306 Grad: 0.4012


 57%|█████▋    | 135/235 [00:06<00:04, 21.21it/s]

Step 600 (N samples: 153,600), Loss: 1.9975 Grad: 0.3501


100%|██████████| 235/235 [00:10<00:00, 21.80it/s]


Step 700 (N samples: 179,200), Loss: 1.9888 Grad: 0.3375


Testing: 100%|██████████| 40/40 [00:01<00:00, 25.76it/s]


====> Test set loss: 1.9744 (Accuracy 0.4573)
Epoch 4/20


 42%|████▏     | 99/235 [00:04<00:06, 22.00it/s]

Step 800 (N samples: 204,800), Loss: 1.9660 Grad: 0.4300


 84%|████████▍ | 198/235 [00:08<00:01, 21.69it/s]

Step 900 (N samples: 230,400), Loss: 1.9201 Grad: 0.3927


100%|██████████| 235/235 [00:10<00:00, 22.15it/s]
Testing: 100%|██████████| 40/40 [00:01<00:00, 26.17it/s]


====> Test set loss: 1.9310 (Accuracy 0.4745)
Epoch 5/20


 27%|██▋       | 63/235 [00:02<00:07, 22.18it/s]

Step 1,000 (N samples: 256,000), Loss: 1.9049 Grad: 0.4731


 70%|███████   | 165/235 [00:07<00:03, 21.50it/s]

Step 1,100 (N samples: 281,600), Loss: 1.9117 Grad: 0.5157


100%|██████████| 235/235 [00:10<00:00, 22.05it/s]
Testing: 100%|██████████| 40/40 [00:01<00:00, 26.06it/s]


====> Test set loss: 1.8917 (Accuracy 0.4750)
Epoch 6/20


 13%|█▎        | 30/235 [00:01<00:09, 20.95it/s]

Step 1,200 (N samples: 307,200), Loss: 1.8994 Grad: 0.4752


 54%|█████▎    | 126/235 [00:17<01:11,  1.52it/s]

Step 1,300 (N samples: 332,800), Loss: 1.8658 Grad: 0.6763


 97%|█████████▋| 229/235 [01:05<00:00, 19.17it/s]

Step 1,400 (N samples: 358,400), Loss: 1.8560 Grad: 0.4827


100%|██████████| 235/235 [01:05<00:00,  3.58it/s]
Testing: 100%|██████████| 40/40 [00:01<00:00, 24.91it/s]


====> Test set loss: 1.8565 (Accuracy 0.4799)
Epoch 7/20


 40%|███▉      | 93/235 [00:04<00:06, 21.11it/s]

Step 1,500 (N samples: 384,000), Loss: 1.8604 Grad: 0.6064


 83%|████████▎ | 195/235 [00:09<00:01, 21.59it/s]

Step 1,600 (N samples: 409,600), Loss: 1.8220 Grad: 0.5983


100%|██████████| 235/235 [00:10<00:00, 21.60it/s]
Testing: 100%|██████████| 40/40 [00:01<00:00, 25.74it/s]


====> Test set loss: 1.8263 (Accuracy 0.4824)
Epoch 8/20


 26%|██▌       | 60/235 [00:02<00:08, 21.72it/s]

Step 1,700 (N samples: 435,200), Loss: 1.8134 Grad: 0.4638


 68%|██████▊   | 159/235 [00:07<00:03, 21.46it/s]

Step 1,800 (N samples: 460,800), Loss: 1.8018 Grad: 0.6606


100%|██████████| 235/235 [00:10<00:00, 21.45it/s]
Testing: 100%|██████████| 40/40 [00:01<00:00, 24.56it/s]


====> Test set loss: 1.7924 (Accuracy 0.4813)
Epoch 9/20


 10%|█         | 24/235 [00:01<00:10, 20.59it/s]

Step 1,900 (N samples: 486,400), Loss: 1.7806 Grad: 0.4027


 53%|█████▎    | 125/235 [00:06<00:05, 21.65it/s]

Step 2,000 (N samples: 512,000), Loss: 1.7661 Grad: 0.4818


 95%|█████████▌| 224/235 [00:10<00:00, 21.79it/s]

Step 2,100 (N samples: 537,600), Loss: 1.7599 Grad: 0.4732


100%|██████████| 235/235 [00:11<00:00, 20.93it/s]
Testing: 100%|██████████| 40/40 [00:01<00:00, 25.23it/s]


====> Test set loss: 1.7696 (Accuracy 0.4757)
Epoch 10/20


 38%|███▊      | 90/235 [00:04<00:06, 20.92it/s]

Step 2,200 (N samples: 563,200), Loss: 1.7416 Grad: 0.6518


 80%|████████  | 189/235 [00:08<00:02, 22.36it/s]

Step 2,300 (N samples: 588,800), Loss: 1.7335 Grad: 0.7924


100%|██████████| 235/235 [00:10<00:00, 21.69it/s]
Testing: 100%|██████████| 40/40 [00:01<00:00, 24.55it/s]


====> Test set loss: 1.7345 (Accuracy 0.4865)
Epoch 11/20


 23%|██▎       | 53/235 [00:02<00:09, 19.48it/s]

Step 2,400 (N samples: 614,400), Loss: 1.7366 Grad: 1.0424


 66%|██████▌   | 155/235 [00:07<00:03, 22.02it/s]

Step 2,500 (N samples: 640,000), Loss: 1.7096 Grad: 0.6270


100%|██████████| 235/235 [00:11<00:00, 21.31it/s]
Testing: 100%|██████████| 40/40 [00:25<00:00,  1.58it/s]


====> Test set loss: 1.7068 (Accuracy 0.4875)
Epoch 12/20


  7%|▋         | 16/235 [00:11<02:30,  1.46it/s]

Step 2,600 (N samples: 665,600), Loss: 1.7001 Grad: 0.6609


 51%|█████     | 119/235 [00:38<00:05, 20.35it/s]

Step 2,700 (N samples: 691,200), Loss: 1.6770 Grad: 0.5719


 93%|█████████▎| 218/235 [00:43<00:00, 20.76it/s]

Step 2,800 (N samples: 716,800), Loss: 1.6658 Grad: 0.7542


100%|██████████| 235/235 [00:44<00:00,  5.29it/s]
Testing: 100%|██████████| 40/40 [00:01<00:00, 24.69it/s]


====> Test set loss: 1.6804 (Accuracy 0.4889)
Epoch 13/20


 36%|███▌      | 84/235 [00:03<00:07, 20.80it/s]

Step 2,900 (N samples: 742,400), Loss: 1.6632 Grad: 0.9679


 78%|███████▊  | 183/235 [00:08<00:02, 21.51it/s]

Step 3,000 (N samples: 768,000), Loss: 1.6524 Grad: 0.4198


100%|██████████| 235/235 [00:11<00:00, 21.07it/s]
Testing: 100%|██████████| 40/40 [00:01<00:00, 24.63it/s]


====> Test set loss: 1.6545 (Accuracy 0.4887)
Epoch 14/20


 21%|██        | 49/235 [00:02<00:09, 20.63it/s]

Step 3,100 (N samples: 793,600), Loss: 1.6602 Grad: 0.8268


 64%|██████▍   | 150/235 [00:07<00:03, 21.27it/s]

Step 3,200 (N samples: 819,200), Loss: 1.6491 Grad: 1.2967


100%|██████████| 235/235 [00:11<00:00, 20.92it/s]
Testing: 100%|██████████| 40/40 [00:01<00:00, 26.12it/s]


====> Test set loss: 1.6287 (Accuracy 0.4897)
Epoch 15/20


  6%|▋         | 15/235 [00:00<00:09, 22.29it/s]

Step 3,300 (N samples: 844,800), Loss: 1.6377 Grad: 1.0742


 49%|████▊     | 114/235 [00:05<00:05, 21.49it/s]

Step 3,400 (N samples: 870,400), Loss: 1.6027 Grad: 0.5541


 91%|█████████ | 213/235 [00:09<00:01, 21.41it/s]

Step 3,500 (N samples: 896,000), Loss: 1.6031 Grad: 0.9794


100%|██████████| 235/235 [00:10<00:00, 21.58it/s]
Testing: 100%|██████████| 40/40 [00:01<00:00, 25.55it/s]


====> Test set loss: 1.6056 (Accuracy 0.4846)
Epoch 16/20


 33%|███▎      | 78/235 [00:03<00:07, 21.56it/s]

Step 3,600 (N samples: 921,600), Loss: 1.5885 Grad: 0.8852


 77%|███████▋  | 180/235 [00:08<00:02, 21.60it/s]

Step 3,700 (N samples: 947,200), Loss: 1.5798 Grad: 0.5909


100%|██████████| 235/235 [00:10<00:00, 21.47it/s]
Testing: 100%|██████████| 40/40 [00:01<00:00, 25.46it/s]


====> Test set loss: 1.5806 (Accuracy 0.4921)
Epoch 17/20


 19%|█▉        | 45/235 [00:02<00:08, 21.35it/s]

Step 3,800 (N samples: 972,800), Loss: 1.5723 Grad: 1.3443


 60%|██████    | 141/235 [00:46<01:04,  1.45it/s]

Step 3,900 (N samples: 998,400), Loss: 1.5393 Grad: 1.2679


100%|██████████| 235/235 [01:08<00:00,  3.43it/s]
Testing: 100%|██████████| 40/40 [00:01<00:00, 24.13it/s]


====> Test set loss: 1.5322 (Accuracy 0.4854)
Epoch 18/20


  4%|▍         | 9/235 [00:00<00:10, 21.04it/s]

Step 4,000 (N samples: 1,024,000), Loss: 1.4978 Grad: 0.9914


 46%|████▌     | 108/235 [00:05<00:06, 19.87it/s]

Step 4,100 (N samples: 1,049,600), Loss: 1.5004 Grad: 1.0967


 89%|████████▉ | 210/235 [00:09<00:01, 20.73it/s]

Step 4,200 (N samples: 1,075,200), Loss: 1.4760 Grad: 0.5114


100%|██████████| 235/235 [00:11<00:00, 21.22it/s]
Testing: 100%|██████████| 40/40 [00:01<00:00, 25.21it/s]


====> Test set loss: 1.5017 (Accuracy 0.4888)
Epoch 19/20


 31%|███▏      | 74/235 [00:03<00:07, 21.10it/s]

Step 4,300 (N samples: 1,100,800), Loss: 1.4789 Grad: 0.6330


 74%|███████▍  | 175/235 [00:08<00:02, 22.98it/s]

Step 4,400 (N samples: 1,126,400), Loss: 1.4304 Grad: 0.6064


100%|██████████| 235/235 [00:10<00:00, 21.45it/s]
Testing: 100%|██████████| 40/40 [00:01<00:00, 26.33it/s]


====> Test set loss: 1.4756 (Accuracy 0.4909)
Epoch 20/20


 17%|█▋        | 39/235 [00:01<00:08, 22.17it/s]

Step 4,500 (N samples: 1,152,000), Loss: 1.4593 Grad: 0.5149


 60%|█████▉    | 140/235 [00:06<00:04, 22.01it/s]

Step 4,600 (N samples: 1,177,600), Loss: 1.4760 Grad: 0.6516


100%|██████████| 235/235 [00:10<00:00, 21.74it/s]
Testing: 100%|██████████| 40/40 [00:01<00:00, 25.89it/s]

====> Test set loss: 1.4567 (Accuracy 0.4874)



