<a href="https://colab.research.google.com/github/kodai-utsunomiya/memorization-and-generalization/blob/main/Mechanism_of_feature_learning_in_deep_fully_connected_networks_and_kernel_machines_that_recursively_learn_features_Grokking.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

(This notebook is hard to run on a CPU; use a GPU runtime if one is available.)

In [28]:
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

cpu


# classic_kernel

### Computing Euclidean distances

**<font color= "green">Function `euclidean_distances(samples, centers, squared=True)`</font>**

Computes the pairwise Euclidean distances between sample points and center points.

- **Sample matrix** $ X $ has shape $(n_{\text{sample}}, n_{\text{feature}})$.
- **Center matrix** $ C $ has shape $(n_{\text{center}}, n_{\text{feature}})$.
- **Squared norms of the samples**, $ X_{\text{norm}} $, are computed row by row:

  $
  (X_{\text{norm}})_i = \|X_i\|^2 = \sum_{k=1}^{n_{\text{feature}}} X_{ik}^2
  $

  where $ X_i $ denotes the $i$-th row of the sample matrix (the $i$-th sample).

- **Squared norms of the centers**, $ C_{\text{norm}} $, are computed in the same way from the center matrix $ C $:

  $
  (C_{\text{norm}})_j = \|C_j\|^2 = \sum_{k=1}^{n_{\text{feature}}} C_{jk}^2
  $

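- The squared Euclidean distance is expanded as

  $
  D_{ij} = \|X_i - C_j\|^2 = \|X_i\|^2 + \|C_j\|^2 - 2 \langle X_i, C_j \rangle
  $

  where $ \langle X_i, C_j \rangle $ is the inner product of sample $ X_i $ and center $ C_j $.

- In matrix form, this is computed as

  $
  D = X_{\text{norm}} \mathbf{1}^T - 2 X C^T + \mathbf{1} C_{\text{norm}}^T
  $

  where $ X_{\text{norm}} $ and $ C_{\text{norm}} $ are column vectors of squared row norms (the diagonals of $ X X^T $ and $ C C^T $). In the code, broadcasting adds them to every column and every row of $ -2 X C^T $, respectively.

- **Square-root option**: with `squared=False`, negative entries caused by round-off are clamped to zero and the element-wise square root is taken, which gives the ordinary (non-squared) Euclidean distance.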
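
The expansion above can be checked numerically. The following is a minimal sketch, not the notebook's own function: the helper name `pairwise_sq_dists` and the random test data are assumptions made for illustration, and `torch.cdist` serves only as a reference.

```python
import torch

def pairwise_sq_dists(X, C):
    """Squared Euclidean distances via ||x||^2 - 2<x, c> + ||c||^2."""
    x_norm = (X ** 2).sum(dim=1, keepdim=True)   # shape (n_sample, 1)
    c_norm = (C ** 2).sum(dim=1).reshape(1, -1)  # shape (1, n_center)
    D = x_norm - 2.0 * (X @ C.T) + c_norm        # broadcasts to (n_sample, n_center)
    return D.clamp(min=0)                        # guard against tiny negatives from round-off

torch.manual_seed(0)
X = torch.randn(5, 3, dtype=torch.float64)  # hypothetical samples
C = torch.randn(4, 3, dtype=torch.float64)  # hypothetical centers

D = pairwise_sq_dists(X, C)
D_ref = torch.cdist(X, C, p=2) ** 2         # reference: direct pairwise distances, squared
max_err = (D - D_ref).abs().max().item()
print(D.shape, max_err)
```

The expanded form needs only one matrix product, which is why it is preferred over looping over all sample-center pairs.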

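**<font color= "green">Function `euclidean_distances_M(samples, centers, M, squared=True)`</font>**

This function computes the distance induced by a matrix $ M $ (a Mahalanobis-type distance) instead of the plain Euclidean distance.

- The squared norm of each sample becomes a quadratic form in $ M $:

  $
  (X_{\text{norm}})_i = X_i M X_i^T
  $
- The squared norms of the centers change in the same way:

  $
  (C_{\text{norm}})_j = C_j M C_j^T
  $

- The distance matrix is then computed as

  $
  D_{ij} = X_i M X_i^T - 2 X_i M C_j^T + C_j M C_j^T
  $

  which equals $ (X_i - C_j) M (X_i - C_j)^T $ when $ M $ is symmetric. With `squared=False`, the element-wise square root of $ D $ is returned.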
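
As a sanity check on the $ M $-weighted distance, the sketch below compares the expanded form with the explicit quadratic form $ (X_i - C_j) M (X_i - C_j)^T $. The helper name `pairwise_sq_dists_M` and the choice $ M = A^T A $ are assumptions made for illustration only.

```python
import torch

def pairwise_sq_dists_M(X, C, M):
    """Squared M-weighted distances: x M x^T - 2 x M c^T + c M c^T."""
    x_norm = ((X @ M) * X).sum(dim=1, keepdim=True)   # x_i M x_i^T, shape (n, 1)
    c_norm = ((C @ M) * C).sum(dim=1).reshape(1, -1)  # c_j M c_j^T, shape (1, m)
    return x_norm - 2.0 * (X @ M @ C.T) + c_norm

torch.manual_seed(0)
X = torch.randn(5, 3, dtype=torch.float64)
C = torch.randn(4, 3, dtype=torch.float64)
A = torch.randn(3, 3, dtype=torch.float64)
M = A.T @ A  # symmetric positive semi-definite, chosen only for this example

D = pairwise_sq_dists_M(X, C, M)

# Reference: explicit quadratic form over all (sample, center) pairs
diff = X[:, None, :] - C[None, :, :]                   # shape (n, m, d)
D_ref = torch.einsum('nmd,de,nme->nm', diff, M, diff)
max_err = (D - D_ref).abs().max().item()

# With M = I the plain squared Euclidean distance is recovered
D_eye = pairwise_sq_dists_M(X, C, torch.eye(3, dtype=torch.float64))
eye_err = (D_eye - torch.cdist(X, C) ** 2).abs().max().item()
print(max_err, eye_err)
```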

### <font color= "green"> Gaussian kernel </font>

**<font color= "green">Function `gaussian(samples, centers, bandwidth)`</font>**

The Gaussian kernel (also called the RBF kernel) measures the similarity between samples and centers.

- The **distance matrix** $ D $ holds the squared Euclidean distances.
- The kernel is

  $
  K_{ij} = \exp\left(-\frac{D_{ij}}{2 \sigma^2}\right)
  $

  where $ \sigma $ is the bandwidth.
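
A small worked example may help here. The sketch below is an illustration under assumptions: the helper name `gaussian_kernel` and the two sample points are hypothetical, and `torch.cdist` stands in for the notebook's distance routine.

```python
import math
import torch

def gaussian_kernel(X, C, bandwidth):
    """K_ij = exp(-||x_i - c_j||^2 / (2 * bandwidth^2))."""
    D = torch.cdist(X, C, p=2) ** 2  # squared Euclidean distances
    return torch.exp(-D / (2.0 * bandwidth ** 2))

# Two points at Euclidean distance 5 from each other
X = torch.tensor([[0.0, 0.0], [3.0, 4.0]], dtype=torch.float64)
K = gaussian_kernel(X, X, bandwidth=5.0)

# Off-diagonal entry: exp(-25 / (2 * 25)) = exp(-0.5); diagonal entries are 1
expected_off_diag = math.exp(-0.5)
print(K)
```

The kernel matrix of a set with itself is symmetric with ones on the diagonal, because every point is at distance zero from itself.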

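### <font color= "green"> Laplacian kernel </font>

**<font color= "green">Function `laplacian(samples, centers, bandwidth)`</font>**

The Laplacian kernel has the same form as the Gaussian kernel but uses the non-squared Euclidean distance in the exponent. Note that this implementation is based on the L2 distance, not the L1 (Manhattan) distance.

- The **distance matrix** $ D $ is the square root of the squared Euclidean distances.
- The kernel is

  $
  K_{ij} = \exp\left(-\frac{D_{ij}}{\sigma}\right)
  $

  where $ \sigma $ is the bandwidth.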
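
The `laplacian` function in this notebook calls `euclidean_distances(..., squared=False)`, so its exponent uses the L2 distance. The sketch below contrasts that with an L1-based variant; the helper name `laplacian_kernel` and the sample points are assumptions for illustration.

```python
import math
import torch

def laplacian_kernel(X, C, bandwidth):
    """K_ij = exp(-||x_i - c_j||_2 / bandwidth), using the non-squared L2 distance."""
    D = torch.cdist(X, C, p=2)
    return torch.exp(-D / bandwidth)

x = torch.tensor([[0.0, 0.0]], dtype=torch.float64)
c = torch.tensor([[3.0, 4.0]], dtype=torch.float64)

# L2 distance is 5, so the kernel value is exp(-5 / 10)
k_l2 = laplacian_kernel(x, c, bandwidth=10.0).item()

# For contrast only: an L1-based variant would use distance 3 + 4 = 7
k_l1 = torch.exp(-torch.cdist(x, c, p=1) / 10.0).item()

print(k_l2, math.exp(-0.5))
print(k_l1, math.exp(-0.7))
```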

### Laplacian kernel with a matrix $ M $

**<font color= "green">Function `laplacian_M(samples, centers, bandwidth, M)`</font>**

- The distance matrix $ D $ is computed by `euclidean_distances_M` with `squared=False`, so it depends on the matrix $ M $.
- The kernel is

  $
  K_{ij} = \exp\left(-\frac{D_{ij}}{\sigma}\right)
  $
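
To see how $ M $ changes the kernel, the sketch below evaluates it once with $ M = I $ and once with a diagonal $ M $ that ignores the second feature. The helper name `laplacian_kernel_M` and both choices of $ M $ are assumptions made for illustration.

```python
import math
import torch

def laplacian_kernel_M(X, C, bandwidth, M):
    """K_ij = exp(-sqrt((x_i - c_j) M (x_i - c_j)^T) / bandwidth), for symmetric M."""
    x_norm = ((X @ M) * X).sum(dim=1, keepdim=True)
    c_norm = ((C @ M) * C).sum(dim=1).reshape(1, -1)
    D = (x_norm - 2.0 * (X @ M @ C.T) + c_norm).clamp(min=0).sqrt()
    return torch.exp(-D / bandwidth)

x = torch.tensor([[0.0, 0.0]], dtype=torch.float64)
c = torch.tensor([[3.0, 4.0]], dtype=torch.float64)

# M = I: ordinary Euclidean distance 5, kernel value exp(-5 / 10)
M_eye = torch.eye(2, dtype=torch.float64)
k_eye = laplacian_kernel_M(x, c, 10.0, M_eye).item()

# M = diag(1, 0): the second feature is ignored, distance 3, kernel value exp(-3 / 10)
M_diag = torch.diag(torch.tensor([1.0, 0.0], dtype=torch.float64))
k_diag = laplacian_kernel_M(x, c, 10.0, M_diag).item()

print(k_eye, k_diag)
```

A matrix $ M $ with small weight on a feature makes points that differ only in that feature look similar to the kernel.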

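### <font color= "green"> Dispersal kernel </font>

**<font color= "green">Function `dispersal(samples, centers, bandwidth, gamma)`</font>**

The dispersal kernel generalizes the kernels above through an exponent parameter $ \gamma $.

- The **distance matrix** $ D $ holds the squared Euclidean distances.
- The kernel is

  $
  K_{ij} = \exp\left(-\frac{D_{ij}^{\gamma / 2}}{\sigma}\right)
  $

  where $ \gamma $ is the dispersal factor and $ \sigma $ is the bandwidth. Since $ D_{ij}^{\gamma / 2} = \|X_i - C_j\|^{\gamma} $, setting $ \gamma = 1 $ recovers the Laplacian kernel.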
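
The role of $ \gamma $ can be illustrated by two special cases. In the sketch below, the helper name `dispersal_kernel`, the random data, and the value of `sigma` are assumptions chosen for illustration.

```python
import torch

def dispersal_kernel(X, C, bandwidth, gamma):
    """K_ij = exp(-(||x_i - c_j||^2)^(gamma / 2) / bandwidth)."""
    D = torch.cdist(X, C, p=2) ** 2  # squared Euclidean distances
    return torch.exp(-D.pow(gamma / 2.0) / bandwidth)

torch.manual_seed(0)
X = torch.randn(6, 3, dtype=torch.float64)
C = torch.randn(4, 3, dtype=torch.float64)
sigma = 2.0

# gamma = 1: exponent is the plain distance, i.e. the Laplacian kernel
lap = torch.exp(-torch.cdist(X, C) / sigma)
err_lap = (dispersal_kernel(X, C, sigma, gamma=1.0) - lap).abs().max().item()

# gamma = 2: exponent is the squared distance, i.e. a Gaussian kernel
# whose bandwidth s satisfies 2 * s^2 = sigma
gauss = torch.exp(-torch.cdist(X, C) ** 2 / sigma)
err_gauss = (dispersal_kernel(X, C, sigma, gamma=2.0) - gauss).abs().max().item()

print(err_lap, err_gauss)
```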

In [29]:
'''Implementation of kernel functions.'''

import torch


def euclidean_distances(samples, centers, squared=True):
    samples_norm = torch.sum(samples**2, dim=1, keepdim=True)
    if samples is centers:
        centers_norm = samples_norm
    else:
        centers_norm = torch.sum(centers**2, dim=1, keepdim=True)
    centers_norm = torch.reshape(centers_norm, (1, -1))

    distances = samples.mm(torch.t(centers))
    distances.mul_(-2)
    distances.add_(samples_norm)
    distances.add_(centers_norm)
    # print(centers_norm.size(), samples_norm.size(), distances.size())
    if not squared:
        distances.clamp_(min=0)
        distances.sqrt_()

    return distances


def euclidean_distances_M(samples, centers, M, squared=True):

    samples_norm = (samples @ M)  * samples
    samples_norm = torch.sum(samples_norm, dim=1, keepdim=True)

    if samples is centers:
        centers_norm = samples_norm
    else:
        centers_norm = (centers @ M) * centers
        centers_norm = torch.sum(centers_norm, dim=1, keepdim=True)

    centers_norm = torch.reshape(centers_norm, (1, -1))

    distances = samples.mm(M @ torch.t(centers))
    distances.mul_(-2)
    distances.add_(samples_norm)
    distances.add_(centers_norm)

    if not squared:
        distances.clamp_(min=0)
        distances.sqrt_()

    return distances


def gaussian(samples, centers, bandwidth):
    '''Gaussian kernel.

    Args:
        samples: of shape (n_sample, n_feature).
        centers: of shape (n_center, n_feature).
        bandwidth: kernel bandwidth.

    Returns:
        kernel matrix of shape (n_sample, n_center).
    '''
    assert bandwidth > 0
    kernel_mat = euclidean_distances(samples, centers)
    kernel_mat.clamp_(min=0)
    gamma = 1. / (2 * bandwidth ** 2)
    kernel_mat.mul_(-gamma)
    kernel_mat.exp_()

    #print(samples.size(), centers.size(),
    #      kernel_mat.size())
    return kernel_mat


def laplacian(samples, centers, bandwidth):
    '''Laplacian kernel.

    Args:
        samples: of shape (n_sample, n_feature).
        centers: of shape (n_center, n_feature).
        bandwidth: kernel bandwidth.

    Returns:
        kernel matrix of shape (n_sample, n_center).
    '''
    assert bandwidth > 0
    kernel_mat = euclidean_distances(samples, centers, squared=False)
    kernel_mat.clamp_(min=0)
    gamma = 1. / bandwidth
    kernel_mat.mul_(-gamma)
    kernel_mat.exp_()
    return kernel_mat



def laplacian_M(samples, centers, bandwidth, M):
    assert bandwidth > 0
    kernel_mat = euclidean_distances_M(samples, centers, M, squared=False)
    kernel_mat.clamp_(min=0)
    gamma = 1. / bandwidth
    kernel_mat.mul_(-gamma)
    kernel_mat.exp_()
    return kernel_mat


def dispersal(samples, centers, bandwidth, gamma):
    '''Dispersal kernel.

    Args:
        samples: of shape (n_sample, n_feature).
        centers: of shape (n_center, n_feature).
        bandwidth: kernel bandwidth.
        gamma: dispersal factor.

    Returns:
        kernel matrix of shape (n_sample, n_center).
    '''
    assert bandwidth > 0
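    kernel_mat = euclidean_distances(samples, centers)
    # Round-off can leave tiny negative entries, which a fractional power would turn into NaN.
    kernel_mat.clamp_(min=0)
    kernel_mat.pow_(gamma / 2.)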
    kernel_mat.mul_(-1. / bandwidth)
    kernel_mat.exp_()
    return kernel_mat

# neural_model

In [30]:
# Only the imports used in this cell are kept; the deprecated `upsample` import is dropped.
import torch
import torch.nn as nn
import torch.nn.functional as F


class Nonlinearity(torch.nn.Module):
    def __init__(self):
        super(Nonlinearity, self).__init__()

    def forward(self, x):
        return F.relu(x)


class Net(nn.Module):

    def __init__(self, dim, num_classes=2):
        super(Net, self).__init__()
        bias = False
        k = 1024
        self.dim = dim
        self.width = k
        self.first = nn.Linear(dim, k, bias=bias)
        self.fc = nn.Sequential(Nonlinearity(),
                                nn.Linear(k, k, bias=bias),
                                Nonlinearity(),
                                nn.Linear(k, num_classes, bias=bias))

    def forward(self, x):
        return self.fc(self.first(x))

# trainer

In [31]:
import torch
from torch.autograd import Variable
import torch.optim as optim
import time
# import neural_model
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import numpy as np


def visualize_M(M, idx):
    d, _ = M.shape
    SIZE = int(np.sqrt(d // 3))
    F1 = np.diag(M[:SIZE**2, :SIZE**2]).reshape(SIZE, SIZE)
    F2 = np.diag(M[SIZE**2:2*SIZE**2, SIZE**2:2*SIZE**2]).reshape(SIZE, SIZE)
    F3 = np.diag(M[2*SIZE**2:, 2*SIZE**2:]).reshape(SIZE, SIZE)
    F = np.stack([F1, F2, F3])
    print(F.shape)
    F = (F - F.min()) / (F.max() - F.min())
    F = np.rollaxis(F, 0, 3)
    plt.imshow(F)
    plt.axis('off')
    plt.savefig('./video_logs/' + str(idx).zfill(6) + '.png',
                bbox_inches='tight', pad_inches = 0)
    return F


def train_network(train_loader, val_loader, test_loader,
                  num_classes=2, name=None,
                  save_frames=False):


    for idx, batch in enumerate(train_loader):
        inputs, labels = batch
        _, dim = inputs.shape
        break

    # neural_model.Net
    net = Net(dim, num_classes=num_classes)

    # Count parameters. The earlier nested loop added to `params` inside the
    # per-dimension loop and therefore reported a wrong total.
    params = sum(p.numel() for p in net.parameters())
    print("NUMBER OF PARAMS: ", params)

    optimizer = torch.optim.SGD(net.parameters(), lr=.1)

    net.to(device)
    num_epochs = 501
    best_val_acc = 0
    best_test_acc = 0
    best_val_loss = float("inf")  # `np.float` was removed in NumPy 1.24
    best_test_loss = 0

    for i in range(num_epochs):
        if save_frames:
            net.cpu()
            for idx, p in enumerate(net.parameters()):
                if idx == 0:
                    M = p.data.numpy()
            M = M.T @ M
            visualize_M(M, i)
            net.to(device)

        if i == 0 or i == 1:
            net.cpu()
            d = {}
            d['state_dict'] = net.state_dict()
            if name is not None:
                torch.save(d, 'nn_models/' + name + '_trained_nn_' + str(i) + '.pth')
            else:
                torch.save(d, 'nn_models/trained_nn.pth')
            net.to(device)

        train_loss = train_step(net, optimizer, train_loader, save_frames=save_frames)
        val_loss = val_step(net, val_loader)
        test_loss = val_step(net, test_loader)
        train_acc = get_acc(net, train_loader)
        val_acc = get_acc(net, val_loader)
        test_acc = get_acc(net, test_loader)

        if val_acc >= best_val_acc:
            best_val_acc = val_acc
            best_test_acc = test_acc
            net.cpu()
            d = {}
            d['state_dict'] = net.state_dict()
            if name is not None:
                torch.save(d, 'nn_models/' + name + '_trained_nn.pth')
            else:
                torch.save(d, 'nn_models/trained_nn.pth')
            net.to(device)

        if val_loss <= best_val_loss:
            best_val_loss = val_loss
            best_test_loss = test_loss

        print("Epoch: ", i,
              "Train Loss: ", train_loss, "Test Loss: ", test_loss,
              "Train Acc: ", train_acc, "Test Acc: ", test_acc,
              "Best Val Acc: ", best_val_acc, "Best Val Loss: ", best_val_loss,
              "Best Test Acc: ", best_test_acc, "Best Test Loss: ", best_test_loss)


def get_data(loader):
    X = []
    y = []
    for idx, batch in enumerate(loader):
        inputs, labels = batch
        X.append(inputs)
        y.append(labels)
    return torch.cat(X, dim=0), torch.cat(y, dim=0)


def train_step(net, optimizer, train_loader, save_frames=False):
    net.train()
    start = time.time()
    train_loss = 0.

    for batch_idx, batch in enumerate(train_loader):
        optimizer.zero_grad()
        inputs, labels = batch
        inputs, labels = inputs.to(device), labels.to(device)
        targets = labels
        output = net(Variable(inputs))
        target = Variable(targets)
        loss = torch.mean(torch.pow(output - target, 2))
        loss.backward()
        optimizer.step()
        train_loss += loss.cpu().data.numpy() * len(inputs)
    end = time.time()
    print("Time: ", end - start)
    train_loss = train_loss / len(train_loader.dataset)
    return train_loss

def val_step(net, val_loader):
    net.eval()
    val_loss = 0.

    for batch_idx, batch in enumerate(val_loader):
        inputs, labels = batch
        inputs, labels = inputs.to(device), labels.to(device)
        targets = labels
        with torch.no_grad():
            output = net(Variable(inputs))
            target = Variable(targets)
        loss = torch.mean(torch.pow(output - target, 2))
        val_loss += loss.cpu().data.numpy() * len(inputs)
    val_loss = val_loss / len(val_loader.dataset)
    return val_loss


def get_acc(net, loader):
    net.eval()
    count = 0
    for batch_idx, batch in enumerate(loader):
        inputs, targets = batch
        inputs, targets = inputs.to(device), targets.to(device)
        with torch.no_grad():
            # inputs and targets are already on `device`; calling .cuda() fails on CPU-only machines
            output = net(inputs)
            target = targets

        preds = torch.argmax(output, dim=-1)
        labels = torch.argmax(target, dim=-1)

        count += torch.sum(labels == preds).cpu().data.numpy()
    return count / len(loader.dataset) * 100

# rfm

In [23]:
!pip install hickle==5.0.2



In [32]:
import numpy as np
import torch
from numpy.linalg import solve
# import classic_kernel
from tqdm import tqdm
import hickle

def laplace_kernel_M(pair1, pair2, bandwidth, M):
    # classic_kernel.laplacian_M
    return laplacian_M(pair1, pair2, bandwidth, M)


def get_grads(X, sol, L, P, batch_size=2):
    M = 0.

    num_samples = 20000
    indices = np.random.randint(len(X), size=num_samples)

    if len(X) > len(indices):
        x = X[indices, :]
    else:
        x = X

    K = laplace_kernel_M(X, x, L, P)

    # classic_kernel.euclidean_distances_M
    dist = euclidean_distances_M(X, x, P, squared=False)
    dist = torch.where(dist < 1e-10, torch.zeros(1).float(), dist)

    K = K / dist
    K[K == float("Inf")] = 0.

    a1 = torch.from_numpy(sol.T).float()
    n, d = X.shape
    n, c = a1.shape
    m, d = x.shape

    a1 = a1.reshape(n, c, 1)
    X1 = (X @ P).reshape(n, 1, d)
    step1 = a1 @ X1
    del a1, X1
    step1 = step1.reshape(-1, c * d)

    step2 = K.T @ step1
    del step1

    step2 = step2.reshape(-1, c, d)

    a2 = torch.from_numpy(sol).float()
    step3 = (a2 @ K).T

    del K, a2

    step3 = step3.reshape(m, c, 1)
    x1 = (x @ P).reshape(m, 1, d)
    step3 = step3 @ x1

    G = (step2 - step3) * -1 / L

    M = 0.

    bs = batch_size
    batches = torch.split(G, bs)
    for i in tqdm(range(len(batches))):
        grad = batches[i]
        if torch.cuda.is_available():
            grad = grad.cuda()
        gradT = torch.transpose(grad, 1, 2)
        M += torch.sum(gradT @ grad, dim=0).cpu()
        del grad, gradT
    torch.cuda.empty_cache()
    M /= len(G)
    M = M.numpy()

    return M


def rfm(train_loader, val_loader, test_loader,
        iters=3, name=None, batch_size=2, reg=1e-3,
        train_acc=False):

    L = 10

    X_train, y_train = get_data(train_loader)
    X_val, y_val = get_data(val_loader)
    X_test, y_test = get_data(test_loader)

    n, d = X_train.shape

    M = np.eye(d, dtype='float32')

    for i in range(iters):
        K_train = laplace_kernel_M(X_train, X_train, L, torch.from_numpy(M)).numpy()
        sol = solve(K_train + reg * np.eye(len(K_train)), y_train).T

        if train_acc:
            preds = (sol @ K_train).T
            y_pred = torch.from_numpy(preds)
            preds = torch.argmax(y_pred, dim=-1)
            labels = torch.argmax(y_train, dim=-1)
            count = torch.sum(labels == preds).numpy()
            print("Round " + str(i) + " Train Acc: ", count / len(labels))

        K_test = laplace_kernel_M(X_train, X_test, L, torch.from_numpy(M)).numpy()
        preds = (sol @ K_test).T
        print("Round " + str(i) + " MSE: ", np.mean(np.square(preds - y_test.numpy())))
        y_pred = torch.from_numpy(preds)
        preds = torch.argmax(y_pred, dim=-1)
        labels = torch.argmax(y_test, dim=-1)
        count = torch.sum(labels == preds).numpy()
        print("Round " + str(i) + " Acc: ", count / len(labels))

        M  = get_grads(X_train, sol, L, torch.from_numpy(M), batch_size=batch_size)
        if name is not None:
            hickle.dump(M, 'saved_Ms/M_' + name + '_' + str(i) + '.h')

    K_train = laplace_kernel_M(X_train, X_train, L, torch.from_numpy(M)).numpy()
    sol = solve(K_train + reg * np.eye(len(K_train)), y_train).T
    K_test = laplace_kernel_M(X_train, X_test, L, torch.from_numpy(M)).numpy()
    preds = (sol @ K_test).T
    mse = np.mean(np.square(preds - y_test.numpy()))
    print("Final MSE: ", mse)
    y_pred = torch.from_numpy(preds)
    preds = torch.argmax(y_pred, dim=-1)
    labels = torch.argmax(y_test, dim=-1)
    count = torch.sum(labels == preds).numpy()
    print(" Final Acc: ", count / len(labels))
    return mse


def get_data(loader):
    X = []
    y = []
    for idx, batch in enumerate(loader):
        inputs, labels = batch
        X.append(inputs)
        y.append(labels)
    return torch.cat(X, dim=0), torch.cat(y, dim=0)

In [25]:
!pip install visdom



# grokking_main

In [None]:
import torch
import torchvision
import torchvision.transforms as transforms
# import trainer as t
from torch.utils.data import Dataset
import random
import torch.backends.cudnn as cudnn
# import rfm
import numpy as np
from sklearn.model_selection import train_test_split
from torch.linalg import norm
from random import randint
import visdom
# import eigenpro_rtfm as erfm
import hickle
# import neural_model

vis = visdom.Visdom('http://127.0.0.1', use_incoming_socket=False)
vis.close(env='main')

SEED = 5636
torch.manual_seed(SEED)
random.seed(SEED)
np.random.seed(SEED)
torch.cuda.manual_seed(SEED)

SIZE = 96
h, w = SIZE, SIZE
locationx = np.array([randint(0, h-1) for i in range(10)], dtype='int')
locationy = np.array([randint(0, w-1) for i in range(10)], dtype='int')

shiftx_l = np.array(locationx - 2, dtype='int')
shiftx_r = np.array(locationx + 3, dtype='int')
shifty_l = np.array(locationy - 2, dtype='int')
shifty_r = np.array(locationy + 3, dtype='int')


def one_hot_data(dataset, num_samples=-1):
    labelset = {}
    for i in range(10):
        one_hot = torch.zeros(10)
        one_hot[i] = 1
        labelset[i] = one_hot

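    # Note: `and` binds tighter than `or`, so every example with label 9 is kept
    # regardless of `num_samples`. The parentheses make this existing behavior explicit.
    subset = [(ex, label) for idx, (ex, label) in enumerate(dataset)
              if (idx < num_samples and label == 0) or label == 9]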

    adjusted = []

    count = 0
    for idx, (ex, label) in enumerate(subset):
        ex[:, 2:7, 7:12] = 0.
        if label == 9:
            count += 1
            ex[:, 2:7, 7:12] = 1.
        if idx < 10:
            vis.image(ex)
        ex = ex.flatten()
        adjusted.append((ex, labelset[label]))
    return adjusted

def split(trainset, p=.8):
    train, val = train_test_split(trainset, train_size=p)
    return train, val

def load_from_net(SIZE=64, path='./nn_models/trained_nn.pth'):
    dim = 3 * SIZE * SIZE
    # neural_model.Net
    net = Net(dim, num_classes=10)

    d = torch.load(path)
    net.load_state_dict(d['state_dict'])
    for idx, p in enumerate(net.parameters()):
        if idx == 0:
            M = p.data.numpy()

    M = M.T @ M
    return M

def main():
    cudnn.benchmark = True
    global SIZE

    transform = transforms.Compose(
        [transforms.ToTensor()
        ])

    path = '~/datasets/'
    trainset = torchvision.datasets.STL10(root=path,
                                          split='train',
                                          transform=transform,
                                          download=True)

    trainset = one_hot_data(trainset, num_samples=500)
    trainset, valset = split(trainset, p=.8)

    print("Train Size: ", len(trainset), "Val Size: ", len(valset))

    trainloader = torch.utils.data.DataLoader(trainset, batch_size=1024,
                                              shuffle=True, num_workers=2)

    valloader = torch.utils.data.DataLoader(valset, batch_size=100,
                                            shuffle=False, num_workers=1)

    testset = torchvision.datasets.STL10(root=path,
                                         split='test',
                                         transform=transform,
                                         download=True)
    testset = one_hot_data(testset, num_samples=1e10)
    print(len(testset))

    testloader = torch.utils.data.DataLoader(testset, batch_size=1024,
                                             shuffle=False, num_workers=2)

    name = 'grokking'
    # rfm.rfm
    rfm(trainloader, valloader, testloader,
            name=name,
            iters=5,
            train_acc=True, reg=1e-3)

    # trainer.train_network
    train_network(trainloader, valloader, testloader,
                    name=name)

if __name__ == "__main__":
    main()

Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/urllib3/connection.py", line 203, in _new_conn
    sock = connection.create_connection(
  File "/usr/local/lib/python3.10/dist-packages/urllib3/util/connection.py", line 85, in create_connection
    raise err
  File "/usr/local/lib/python3.10/dist-packages/urllib3/util/connection.py", line 73, in create_connection
    sock.connect(sa)
ConnectionRefusedError: [Errno 111] Connection refused

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/urllib3/connectionpool.py", line 791, in urlopen
    response = self._make_request(
  File "/usr/local/lib/python3.10/dist-packages/urllib3/connectionpool.py", line 497, in _make_request
    conn.request(
  File "/usr/local/lib/python3.10/dist-packages/urllib3/connection.py", line 395, in request
    self.endheaders()
  File "/usr/lib/python3.10/http/client.py", line 1

Files already downloaded and verified
Train Size:  442 Val Size:  111
Files already downloaded and verified
1600
Round 0 Train Acc:  1.0
Round 0 MSE:  0.06461313481906976
Round 0 Acc:  0.55875


  0%|          | 0/221 [00:00<?, ?it/s]