In [1]:
import torch
from sklearn.model_selection import train_test_split
from torch.utils.data import Subset
import torch.nn as nn
from torchvision import transforms
import torchvision
import random
from tqdm import tqdm
from PIL import Image

In [2]:
def WaveletTransformAxisY(batch_img):
    odd_img  = batch_img[:,0::2]
    even_img = batch_img[:,1::2]
    L = (odd_img + even_img) / 2.0
    H = torch.abs(odd_img - even_img)

    return L, H

In [3]:
def WaveletTransformAxisX(batch_img):
    # transpose + fliplr
    tmp_batch = torch.permute(batch_img, (0, 2, 1))
    tmp_batch = torch.fliplr(tmp_batch)
    _dst_L, _dst_H = WaveletTransformAxisY(tmp_batch)
    # transpose + flipud
    dst_L = torch.permute(_dst_L, [0, 2, 1])
    dst_L = torch.flipud(dst_L)
    dst_H = torch.permute(_dst_H, [0, 2, 1])
    dst_H = torch.flipud(dst_H)

    return dst_L, dst_H

In [4]:
def Wavelet(batch_image):
    r = batch_image[:,0]
    g = batch_image[:,1]
    b = batch_image[:,2]

    # level 1 decomposition
    wavelet_L, wavelet_H = WaveletTransformAxisY(r)
    r_wavelet_LL, r_wavelet_LH = WaveletTransformAxisX(wavelet_L)
    r_wavelet_HL, r_wavelet_HH = WaveletTransformAxisX(wavelet_H)

    wavelet_L, wavelet_H = WaveletTransformAxisY(g)
    g_wavelet_LL, g_wavelet_LH = WaveletTransformAxisX(wavelet_L)
    g_wavelet_HL, g_wavelet_HH = WaveletTransformAxisX(wavelet_H)

    wavelet_L, wavelet_H = WaveletTransformAxisY(b)
    b_wavelet_LL, b_wavelet_LH = WaveletTransformAxisX(wavelet_L)
    b_wavelet_HL, b_wavelet_HH = WaveletTransformAxisX(wavelet_H)

    wavelet_data = [r_wavelet_LL, r_wavelet_LH, r_wavelet_HL, r_wavelet_HH,
                    g_wavelet_LL, g_wavelet_LH, g_wavelet_HL, g_wavelet_HH,
                    b_wavelet_LL, b_wavelet_LH, b_wavelet_HL, b_wavelet_HH]
    transform_batch = torch.stack(wavelet_data, axis=1)

    # level 2 decomposition
    wavelet_L2, wavelet_H2 = WaveletTransformAxisY(r_wavelet_LL)
    r_wavelet_LL2, r_wavelet_LH2 = WaveletTransformAxisX(wavelet_L2)
    r_wavelet_HL2, r_wavelet_HH2 = WaveletTransformAxisX(wavelet_H2)

    wavelet_L2, wavelet_H2 = WaveletTransformAxisY(g_wavelet_LL)
    g_wavelet_LL2, g_wavelet_LH2 = WaveletTransformAxisX(wavelet_L2)
    g_wavelet_HL2, g_wavelet_HH2 = WaveletTransformAxisX(wavelet_H2)

    wavelet_L2, wavelet_H2 = WaveletTransformAxisY(b_wavelet_LL)
    b_wavelet_LL2, b_wavelet_LH2 = WaveletTransformAxisX(wavelet_L2)
    b_wavelet_HL2, b_wavelet_HH2 = WaveletTransformAxisX(wavelet_H2)


    wavelet_data_l2 = [r_wavelet_LL2, r_wavelet_LH2, r_wavelet_HL2, r_wavelet_HH2,
                       g_wavelet_LL2, g_wavelet_LH2, g_wavelet_HL2, g_wavelet_HH2,
                       b_wavelet_LL2, b_wavelet_LH2, b_wavelet_HL2, b_wavelet_HH2]
    transform_batch_l2 = torch.stack(wavelet_data_l2, dim=1)

    # level 3 decomposition
    wavelet_L3, wavelet_H3 = WaveletTransformAxisY(r_wavelet_LL2)
    r_wavelet_LL3, r_wavelet_LH3 = WaveletTransformAxisX(wavelet_L3)
    r_wavelet_HL3, r_wavelet_HH3 = WaveletTransformAxisX(wavelet_H3)

    wavelet_L3, wavelet_H3 = WaveletTransformAxisY(g_wavelet_LL2)
    g_wavelet_LL3, g_wavelet_LH3 = WaveletTransformAxisX(wavelet_L3)
    g_wavelet_HL3, g_wavelet_HH3 = WaveletTransformAxisX(wavelet_H3)

    wavelet_L3, wavelet_H3 = WaveletTransformAxisY(b_wavelet_LL2)
    b_wavelet_LL3, b_wavelet_LH3 = WaveletTransformAxisX(wavelet_L3)
    b_wavelet_HL3, b_wavelet_HH3 = WaveletTransformAxisX(wavelet_H3)

    wavelet_data_l3 = [r_wavelet_LL3, r_wavelet_LH3, r_wavelet_HL3, r_wavelet_HH3,
                       g_wavelet_LL3, g_wavelet_LH3, g_wavelet_HL3, g_wavelet_HH3,
                       b_wavelet_LL3, b_wavelet_LH3, b_wavelet_HL3, b_wavelet_HH3]
    transform_batch_l3 = torch.stack(wavelet_data_l3, dim=1)

    # level 4 decomposition
    wavelet_L4, wavelet_H4 = WaveletTransformAxisY(r_wavelet_LL3)
    r_wavelet_LL4, r_wavelet_LH4 = WaveletTransformAxisX(wavelet_L4)
    r_wavelet_HL4, r_wavelet_HH4 = WaveletTransformAxisX(wavelet_H4)

    wavelet_L4, wavelet_H4 = WaveletTransformAxisY(g_wavelet_LL3)
    g_wavelet_LL4, g_wavelet_LH4 = WaveletTransformAxisX(wavelet_L4)
    g_wavelet_HL4, g_wavelet_HH4 = WaveletTransformAxisX(wavelet_H4)

    wavelet_L4, wavelet_H4 = WaveletTransformAxisY(b_wavelet_LL3)
    b_wavelet_LL4, b_wavelet_LH4 = WaveletTransformAxisX(wavelet_L4)
    b_wavelet_HL4, b_wavelet_HH4 = WaveletTransformAxisX(wavelet_H4)


    wavelet_data_l4 = [r_wavelet_LL4, r_wavelet_LH4, r_wavelet_HL4, r_wavelet_HH4,
                       g_wavelet_LL4, g_wavelet_LH4, g_wavelet_HL4, g_wavelet_HH4,
                       b_wavelet_LL4, b_wavelet_LH4, b_wavelet_HL4, b_wavelet_HH4]
    transform_batch_l4 = torch.stack(wavelet_data_l4, dim=1)

    return [transform_batch, transform_batch_l2, transform_batch_l3, transform_batch_l4]

In [5]:
class Wavelet_Model(torch.nn.Module):
    def __init__(self, classes=10):
        super(Wavelet_Model, self).__init__()
        self.conv_1 = nn.Conv2d(12, 64, kernel_size=(3, 3), padding=1)
        self.norm_1 = nn.BatchNorm2d(64)
        self.relu_1 = nn.ReLU()

        self.conv_1_2 = nn.Conv2d(64, 64, kernel_size=(3, 3), stride=(2, 2), padding=1)
        self.norm_1_2 = nn.BatchNorm2d(64)
        self.relu_1_2 = nn.ReLU()
        #################################################################################################
        self.conv_a = nn.Conv2d(12, 64, kernel_size=(3, 3), padding=1)
        self.norm_a = nn.BatchNorm2d(64)
        self.relu_a = nn.ReLU()
        #################################################################################################
        self.conv_2 = nn.Conv2d(128, 128, kernel_size=(3, 3), padding=1)
        self.norm_2 = nn.BatchNorm2d(128)
        self.relu_2 = nn.ReLU()

        self.conv_2_2 = nn.Conv2d(128, 128, kernel_size=(3, 3), stride=(2, 2), padding=1)
        self.norm_2_2 = nn.BatchNorm2d(128)
        self.relu_2_2 = nn.ReLU()
        #################################################################################################
        self.conv_b = nn.Conv2d(12, 128, kernel_size=(3, 3), padding=1)
        self.norm_b = nn.BatchNorm2d(128)
        self.relu_b = nn.ReLU()

        self.conv_b_2 = nn.Conv2d(128, 128, kernel_size=(3, 3), padding=1)
        self.norm_b_2 = nn.BatchNorm2d(128)
        self.relu_b_2 = nn.ReLU()
        #################################################################################################
        self.conv_3 = nn.Conv2d(256, 256, kernel_size=(3, 3), padding=1)
        self.norm_3 = nn.BatchNorm2d(256)
        self.relu_3 = nn.ReLU()

        self.conv_3_2 = nn.Conv2d(256, 256, kernel_size=(3, 3), stride=(2, 2), padding=1)
        self.norm_3_2 = nn.BatchNorm2d(256)
        self.relu_3_2 = nn.ReLU()
        #################################################################################################
        self.conv_c = nn.Conv2d(12, 256, kernel_size=(3, 3), padding=1)
        self.norm_c = nn.BatchNorm2d(256)
        self.relu_c = nn.ReLU()

        self.conv_c_2 = nn.Conv2d(256, 256, kernel_size=(3, 3), padding=1)
        self.norm_c_2 = nn.BatchNorm2d(256)
        self.relu_c_2 = nn.ReLU()

        self.conv_c_3 = nn.Conv2d(256, 256, kernel_size=(3, 3), padding=1)
        self.norm_c_3 = nn.BatchNorm2d(256)
        self.relu_c_3 = nn.ReLU()
        #################################################################################################
        self.conv_4 = nn.Conv2d(512, 256, kernel_size=(3, 3), padding=1)
        self.norm_4 = nn.BatchNorm2d(256)
        self.relu_4 = nn.ReLU()

        self.conv_4_2 = nn.Conv2d(256, 128, kernel_size=(3, 3), stride=(2, 2), padding=1)
        self.norm_4_2 = nn.BatchNorm2d(128)
        self.relu_4_2 = nn.ReLU()
        #################################################################################################
        self.conv_5 = nn.Conv2d(128, 128, kernel_size=(3, 3), padding=1)
        self.norm_5 = nn.BatchNorm2d(128)
        self.relu_5 = nn.ReLU()

        self.pool_5 = nn.AvgPool2d(kernel_size=(7,7), stride=1, padding=1)
        self.flat_5 = nn.Flatten()

        self.fc_5 = nn.Linear(1152, 2048)
        self.norm_5_1 = nn.BatchNorm1d(2048)
        self.relu_5_1 = nn.ReLU()
        self.drop_5 = nn.Dropout(0.5)
        #################################################################################################
        self.fc_6 = nn.Linear(2048, classes)
        self.norm_6 = nn.BatchNorm1d(classes)
        self.relu_6 = nn.ReLU()
        self.drop_6 = nn.Dropout(0.5)
        #################################################################################################
        self.output_fc = nn.Linear(classes, classes)

    def forward(self, x):
        input_l1, input_l2, input_l3, input_l4 = Wavelet(x)
        #################################################################################################
        # print('input shape: ', input_l1.shape)
        out_1 = self.conv_1(input_l1)
        # print('conv_1 output shape: ', out_1.shape)
        out_1 = self.norm_1(out_1)
        out_1 = self.relu_1(out_1)

        out_1 = self.conv_1_2(out_1)
        # print('conv_1_2 output shape: ', out_1.shape)
        out_1 = self.norm_1_2(out_1)
        out_1 = self.relu_1_2(out_1)
        #################################################################################################
        out_2 = self.conv_a(input_l2)
        # print('conv_a output shape: ', out_2.shape)
        out_2 = self.norm_a(out_2)
        out_2 = self.relu_a(out_2)

        cat_2 = torch.cat((out_1, out_2), 1)
        # print('concatenate result: ', cat_2.shape)
        out_2 = self.conv_2(cat_2)
        # print('conv_2 output shape: ', out_2.shape)
        out_2 = self.norm_2(out_2)
        out_2 = self.relu_2(out_2)

        out_2 = self.conv_2_2(out_2)
        # print('conv_2_2 output shape: ', out_2.shape)
        out_2 = self.norm_2_2(out_2)
        out_2 = self.relu_2_2(out_2)
        #################################################################################################
        out_3 = self.conv_b(input_l3)
        # print('conv_b output shape: ', out_3.shape)
        out_3 = self.norm_b(out_3)
        out_3 = self.relu_b(out_3)

        out_3 = self.conv_b_2(out_3)
        # print('conv_b_2 output shape: ', out_3.shape)
        out_3 = self.norm_b_2(out_3)
        out_3 = self.relu_b_2(out_3)
        #################################################################################################
        cat_3 = torch.cat((out_2, out_3), 1)
        # print('concatenate result: ', cat_3.shape)
        out_3 = self.conv_3(cat_3)
        # print('conv_3 output shape: ', out_3.shape)
        out_3 = self.norm_3(out_3)
        out_3 = self.relu_3(out_3)

        out_3 = self.conv_3_2(out_3)
        # print('conv_3_2 output shape: ', out_3.shape)
        out_3 = self.norm_3_2(out_3)
        out_3 = self.relu_3_2(out_3)
        #################################################################################################
        out_4 = self.conv_c(input_l4)
        # print('conv_c output shape: ', out_4.shape)
        out_4 = self.norm_c(out_4)
        out_4 = self.relu_c(out_4)

        out_4 = self.conv_c_2(out_4)
        # print('conv_c_2 output shape: ', out_4.shape)
        out_4 = self.norm_c_2(out_4)
        out_4 = self.relu_c_2(out_4)

        out_4 = self.conv_c_3(out_4)
        # print('conv_c_3 output shape: ', out_4.shape)
        out_4 = self.norm_c_3(out_4)
        out_4 = self.relu_c_3(out_4)
        #################################################################################################
        cat_4 = torch.cat((out_3, out_4), 1)
        # print('concatenate result: ', cat_4.shape)
        out_4 = self.conv_4(cat_4)
        # print('conv_4 output shape: ', out_4.shape)
        out_4 = self.norm_4(out_4)
        out_4 = self.relu_4(out_4)

        out_4 = self.conv_4_2(out_4)
        # print('conv_4_2 output shape: ', out_4.shape)
        out_4 = self.norm_4_2(out_4)
        out_4 = self.relu_4_2(out_4)
        #################################################################################################
        out_5 = self.conv_5(out_4)
        # print('conv_5 output shape: ', out_5.shape)
        out_5 = self.norm_5(out_5)
        out_5 = self.relu_5(out_5)

        out_5 = self.pool_5(out_5)
        out_5 = self.flat_5(out_5)
        #################################################################################################
        out_5 = self.fc_5(out_5)
        # print('fc_5 output shape: ', out_5.shape)
        out_5 = self.norm_5_1(out_5)
        out_5 = self.relu_5_1(out_5)
        out_5 = self.drop_5(out_5)

        out_6 = self.fc_6(out_5)
        # print('fc6 output shape: ', out_6.shape)
        out_6 = self.norm_6(out_6)
        out_6 = self.relu_6(out_6)
        out_6 = self.drop_6(out_6)
        #################################################################################################
        output = self.output_fc(out_6)

        return output

In [6]:
import wandb

wandb.login()

[34m[1mwandb[0m: Currently logged in as: [33mmindw96[0m (use `wandb login --relogin` to force relogin)


True

In [7]:
batch_size  = 512
random_seed = 888
random.seed(random_seed)
torch.manual_seed(random_seed)

device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [23]:
config = dict(
    epochs=300,
    classes=58,
    batch_size=512,
    learning_rate=0.0005,
    dataset="KTH-TIPS2+DTD",
    architecture="Wavelet_CNN",
    data_path = './dataset/'
)

In [24]:
def model_pipeline(hyperparameters):

    # tell wandb to get started
    with wandb.init(project="Texture Classification", config=hyperparameters):
        # access all HPs through wandb.config, so logging matches execution!
        config = wandb.config

        # make the model, data, and optimization problem
        model, train_loader, test_loader, criterion, optimizer = make(config)

        # and use them to train the model
        train(model, train_loader, criterion, optimizer, config)

        # and test its final performance
        # skip test
        # test(model, test_loader)

    return model

In [25]:
def make(config):
    # Make the data
    datasets = get_data(config.data_path)
    train_loader = make_loader(datasets, batch_size=config.batch_size)
    # valid_loader = make_loader(datasets['valid'], batch_size=config.batch_size)

    # Make the model
    model = Wavelet_Model(config.classes).to(device)

    # Make the loss and optimizer
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(
        model.parameters(), lr=config.learning_rate)

    return model, train_loader, criterion, optimizer

In [26]:
def get_data(data_path=''):
    texture_dataset = torchvision.datasets.ImageFolder(
        data_path,
        transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ]))
    #  equiv to slicing with [::slice]
    # train_idx, val_idx = train_test_split(list(range(len(texture_dataset))), test_size=0.2, random_state=random_seed)
    # datasets = {}
    # datasets['train'] = Subset(texture_dataset, train_idx)
    # datasets['valid'] = Subset(texture_dataset, val_idx)

    return texture_dataset


def make_loader(dataset, batch_size):
    loader = torch.utils.data.DataLoader(dataset=dataset,
                                         batch_size=batch_size,
                                         shuffle=True,
                                         pin_memory=True, num_workers=20)
    return loader

In [27]:
def train(model, loader, criterion, optimizer, config):
    # tell wandb to watch what the model gets up to: gradients, weights, and more!
    wandb.watch(model, criterion, log="all", log_freq=10)

    # Run training and track with wandb
    total_batches = len(loader) * config.epochs
    example_ct = 0  # number of examples seen
    batch_ct = 0
    for epoch in tqdm(range(config.epochs)):
        for _, (images, labels) in enumerate(loader):

            loss, accuracy = train_batch(images, labels, model, optimizer, criterion)
            example_ct +=  len(images)
            batch_ct += 1

            # Report metrics every 25th batch
            if ((batch_ct + 1) % 25) == 0:
                train_log(loss, example_ct, epoch, accuracy)

    torch.onnx.export(model, images, "model.onnx")
    wandb.save("model.onnx")


def train_batch(images, labels, model, optimizer, criterion):
    images, labels = images.to(device), labels.to(device)
    correct, total = 0, 0
    # Forward pass ➡
    outputs = model(images)
    loss = criterion(outputs, labels)
    _, predicted = torch.max(outputs.data, 1)
    total += labels.size(0)
    correct += (predicted == labels).sum().item()
    accuracy = correct / total

    # Backward pass ⬅
    optimizer.zero_grad()
    loss.backward()

    # Step with optimizer
    optimizer.step()

    return loss, accuracy

In [28]:
def train_log(loss, example_ct, epoch, accuracy):
    loss = float(loss)

    # where the magic happens
    wandb.log({"epoch": epoch, "loss": loss, "accuracy": accuracy}, step=example_ct)
    print(f"Loss after " + str(example_ct).zfill(5) + f" examples: {loss:.3f}")

In [29]:
def test(model, test_loader):
    model.eval()

    # Run the model on some test examples
    with torch.no_grad():
        correct, total = 0, 0
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

        print(f"Accuracy of the model on the {total} " +
              f"test images: {100 * correct / total}%")

        wandb.log({"test_accuracy": correct / total})

    # Save the model in the exchangeable ONNX format
    torch.onnx.export(model, images, "model.onnx")
    wandb.save("model.onnx")

In [30]:
model = model_pipeline(config)

Wavelet_Model(
  (conv_1): Conv2d(12, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (norm_1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu_1): ReLU()
  (conv_1_2): Conv2d(64, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
  (norm_1_2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu_1_2): ReLU()
  (conv_a): Conv2d(12, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (norm_a): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu_a): ReLU()
  (conv_2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (norm_2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu_2): ReLU()
  (conv_2_2): Conv2d(128, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
  (norm_2_2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu_2_2): ReLU()
  (conv_b): Conv2d(12,

  0%|          | 1/300 [00:33<2:44:38, 33.04s/it]

Loss after 11897 examples: 3.280


  1%|          | 2/300 [01:05<2:41:46, 32.57s/it]

Loss after 24306 examples: 2.844


  1%|▏         | 4/300 [02:11<2:42:25, 32.92s/it]

Loss after 36324 examples: 2.627


  2%|▏         | 5/300 [02:43<2:40:09, 32.57s/it]

Loss after 48733 examples: 2.505


  2%|▏         | 7/300 [03:45<2:34:38, 31.67s/it]

Loss after 60751 examples: 2.598


  3%|▎         | 8/300 [04:14<2:30:51, 31.00s/it]

Loss after 73160 examples: 2.269


  3%|▎         | 10/300 [05:15<2:28:45, 30.78s/it]

Loss after 85178 examples: 2.323


  4%|▎         | 11/300 [05:46<2:27:13, 30.57s/it]

Loss after 97587 examples: 2.175


  4%|▍         | 13/300 [06:46<2:25:12, 30.36s/it]

Loss after 109605 examples: 2.238


  5%|▍         | 14/300 [07:16<2:24:57, 30.41s/it]

Loss after 122014 examples: 2.119


  5%|▌         | 16/300 [08:17<2:22:59, 30.21s/it]

Loss after 134032 examples: 2.140


  6%|▌         | 17/300 [08:46<2:20:57, 29.88s/it]

Loss after 146441 examples: 2.070


  6%|▋         | 19/300 [09:46<2:20:21, 29.97s/it]

Loss after 158459 examples: 1.913


  7%|▋         | 20/300 [10:15<2:19:04, 29.80s/it]

Loss after 170868 examples: 1.915


  7%|▋         | 21/300 [10:46<2:19:38, 30.03s/it]

Loss after 182886 examples: 1.905


  8%|▊         | 23/300 [11:46<2:19:11, 30.15s/it]

Loss after 195295 examples: 1.819


  8%|▊         | 24/300 [12:17<2:19:42, 30.37s/it]

Loss after 207704 examples: 1.999


  9%|▊         | 26/300 [13:17<2:17:22, 30.08s/it]

Loss after 219722 examples: 1.933


  9%|▉         | 27/300 [13:48<2:17:21, 30.19s/it]

Loss after 232131 examples: 1.778


 10%|▉         | 29/300 [14:48<2:16:55, 30.32s/it]

Loss after 244149 examples: 1.738


 10%|█         | 30/300 [15:19<2:17:08, 30.48s/it]

Loss after 256558 examples: 1.742


 11%|█         | 32/300 [16:19<2:15:30, 30.34s/it]

Loss after 268576 examples: 1.773


 11%|█         | 33/300 [16:49<2:14:11, 30.15s/it]

Loss after 280985 examples: 1.809


 12%|█▏        | 35/300 [17:58<2:23:18, 32.45s/it]

Loss after 293003 examples: 1.680


 12%|█▏        | 36/300 [18:32<2:24:38, 32.87s/it]

Loss after 305412 examples: 1.447


 13%|█▎        | 38/300 [19:42<2:27:22, 33.75s/it]

Loss after 317430 examples: 1.695


 13%|█▎        | 39/300 [20:15<2:26:22, 33.65s/it]

Loss after 329839 examples: 1.570


 14%|█▎        | 41/300 [21:23<2:25:25, 33.69s/it]

Loss after 341857 examples: 1.450


 14%|█▍        | 42/300 [21:56<2:24:41, 33.65s/it]

Loss after 354266 examples: 1.351


 15%|█▍        | 44/300 [23:07<2:26:29, 34.33s/it]

Loss after 366284 examples: 1.344


 15%|█▌        | 45/300 [23:40<2:24:26, 33.99s/it]

Loss after 378693 examples: 1.407


 15%|█▌        | 46/300 [24:20<2:31:21, 35.75s/it]

Loss after 390711 examples: 1.209


 16%|█▌        | 48/300 [25:28<2:26:49, 34.96s/it]

Loss after 403120 examples: 1.365


 16%|█▋        | 49/300 [26:04<2:26:44, 35.08s/it]

Loss after 415529 examples: 1.319


 17%|█▋        | 51/300 [27:15<2:26:54, 35.40s/it]

Loss after 427547 examples: 1.300


 17%|█▋        | 52/300 [27:50<2:25:40, 35.24s/it]

Loss after 439956 examples: 1.254


 18%|█▊        | 54/300 [29:01<2:24:36, 35.27s/it]

Loss after 451974 examples: 1.267


 18%|█▊        | 55/300 [29:35<2:23:03, 35.03s/it]

Loss after 464383 examples: 1.193


 19%|█▉        | 57/300 [30:44<2:20:49, 34.77s/it]

Loss after 476401 examples: 1.177


 19%|█▉        | 58/300 [31:19<2:20:06, 34.74s/it]

Loss after 488810 examples: 1.076


 20%|██        | 60/300 [32:29<2:19:54, 34.98s/it]

Loss after 500828 examples: 1.198


 20%|██        | 61/300 [33:05<2:20:28, 35.27s/it]

Loss after 513237 examples: 1.217


 21%|██        | 63/300 [34:15<2:18:45, 35.13s/it]

Loss after 525255 examples: 1.100


 21%|██▏       | 64/300 [34:49<2:16:13, 34.63s/it]

Loss after 537664 examples: 1.071


 22%|██▏       | 66/300 [35:57<2:14:31, 34.49s/it]

Loss after 549682 examples: 0.987


 22%|██▏       | 67/300 [36:33<2:15:13, 34.82s/it]

Loss after 562091 examples: 0.841


 23%|██▎       | 69/300 [37:41<2:13:02, 34.56s/it]

Loss after 574109 examples: 0.979


 23%|██▎       | 70/300 [38:15<2:12:05, 34.46s/it]

Loss after 586518 examples: 0.987


 24%|██▎       | 71/300 [38:51<2:12:11, 34.64s/it]

Loss after 598536 examples: 1.031


 24%|██▍       | 73/300 [40:00<2:11:01, 34.63s/it]

Loss after 610945 examples: 0.934


 25%|██▍       | 74/300 [40:35<2:11:03, 34.80s/it]

Loss after 623354 examples: 0.779


 25%|██▌       | 76/300 [41:45<2:10:33, 34.97s/it]

Loss after 635372 examples: 0.822


 26%|██▌       | 77/300 [42:21<2:10:09, 35.02s/it]

Loss after 647781 examples: 0.792


 26%|██▋       | 79/300 [43:29<2:07:15, 34.55s/it]

Loss after 659799 examples: 0.781


 27%|██▋       | 80/300 [44:04<2:07:08, 34.68s/it]

Loss after 672208 examples: 0.750


 27%|██▋       | 82/300 [45:13<2:05:36, 34.57s/it]

Loss after 684226 examples: 0.752


 28%|██▊       | 83/300 [45:47<2:04:56, 34.54s/it]

Loss after 696635 examples: 0.760


 28%|██▊       | 85/300 [46:56<2:03:35, 34.49s/it]

Loss after 708653 examples: 0.703


 29%|██▊       | 86/300 [47:31<2:03:59, 34.77s/it]

Loss after 721062 examples: 0.655


 29%|██▉       | 88/300 [48:40<2:02:14, 34.60s/it]

Loss after 733080 examples: 0.555


 30%|██▉       | 89/300 [49:15<2:01:25, 34.53s/it]

Loss after 745489 examples: 0.615


 30%|███       | 91/300 [50:24<2:00:00, 34.45s/it]

Loss after 757507 examples: 0.640


 31%|███       | 92/300 [50:58<1:59:23, 34.44s/it]

Loss after 769916 examples: 0.586


 31%|███▏      | 94/300 [52:06<1:58:01, 34.38s/it]

Loss after 781934 examples: 0.591


 32%|███▏      | 95/300 [52:41<1:57:43, 34.46s/it]

Loss after 794343 examples: 0.474


 32%|███▏      | 96/300 [53:15<1:56:48, 34.36s/it]

Loss after 806361 examples: 0.509


 33%|███▎      | 98/300 [54:23<1:54:58, 34.15s/it]

Loss after 818770 examples: 0.470


 33%|███▎      | 99/300 [54:57<1:54:22, 34.14s/it]

Loss after 831179 examples: 0.387


 34%|███▎      | 101/300 [56:08<1:55:33, 34.84s/it]

Loss after 843197 examples: 0.483


 34%|███▍      | 102/300 [56:42<1:54:22, 34.66s/it]

Loss after 855606 examples: 0.439


 35%|███▍      | 104/300 [57:52<1:53:18, 34.69s/it]

Loss after 867624 examples: 0.528


 35%|███▌      | 105/300 [58:26<1:52:18, 34.56s/it]

Loss after 880033 examples: 0.433


 36%|███▌      | 107/300 [59:35<1:51:21, 34.62s/it]

Loss after 892051 examples: 0.350


 36%|███▌      | 108/300 [1:00:11<1:51:38, 34.89s/it]

Loss after 904460 examples: 0.380


 37%|███▋      | 110/300 [1:01:20<1:50:23, 34.86s/it]

Loss after 916478 examples: 0.357


 37%|███▋      | 111/300 [1:01:55<1:49:48, 34.86s/it]

Loss after 928887 examples: 0.405


 38%|███▊      | 113/300 [1:03:04<1:48:05, 34.68s/it]

Loss after 940905 examples: 0.354


 38%|███▊      | 114/300 [1:03:37<1:46:13, 34.27s/it]

Loss after 953314 examples: 0.359


 39%|███▊      | 116/300 [1:04:47<1:45:51, 34.52s/it]

Loss after 965332 examples: 0.440


 39%|███▉      | 117/300 [1:05:22<1:45:44, 34.67s/it]

Loss after 977741 examples: 0.296


 40%|███▉      | 119/300 [1:06:32<1:45:05, 34.84s/it]

Loss after 989759 examples: 0.283


 40%|████      | 120/300 [1:07:06<1:43:49, 34.61s/it]

Loss after 1002168 examples: 0.366


 40%|████      | 121/300 [1:07:40<1:42:27, 34.34s/it]

Loss after 1014186 examples: 0.397


 41%|████      | 123/300 [1:08:50<1:42:03, 34.59s/it]

Loss after 1026595 examples: 0.373


 41%|████▏     | 124/300 [1:09:23<1:40:27, 34.25s/it]

Loss after 1039004 examples: 0.277


 42%|████▏     | 126/300 [1:10:32<1:39:17, 34.24s/it]

Loss after 1051022 examples: 0.281


 42%|████▏     | 127/300 [1:11:06<1:38:46, 34.25s/it]

Loss after 1063431 examples: 0.220


 43%|████▎     | 129/300 [1:12:15<1:38:05, 34.42s/it]

Loss after 1075449 examples: 0.263


 43%|████▎     | 130/300 [1:12:50<1:38:00, 34.59s/it]

Loss after 1087858 examples: 0.193


 44%|████▍     | 132/300 [1:13:57<1:35:35, 34.14s/it]

Loss after 1099876 examples: 0.240


 44%|████▍     | 133/300 [1:14:31<1:34:51, 34.08s/it]

Loss after 1112285 examples: 0.219


 45%|████▌     | 135/300 [1:15:41<1:34:38, 34.42s/it]

Loss after 1124303 examples: 0.304


 45%|████▌     | 136/300 [1:16:15<1:34:25, 34.54s/it]

Loss after 1136712 examples: 0.241


 46%|████▌     | 138/300 [1:17:25<1:33:30, 34.63s/it]

Loss after 1148730 examples: 0.184


 46%|████▋     | 139/300 [1:17:59<1:32:35, 34.50s/it]

Loss after 1161139 examples: 0.193


 47%|████▋     | 141/300 [1:19:07<1:30:51, 34.29s/it]

Loss after 1173157 examples: 0.196


 47%|████▋     | 142/300 [1:19:42<1:30:14, 34.27s/it]

Loss after 1185566 examples: 0.143


 48%|████▊     | 144/300 [1:20:52<1:30:15, 34.71s/it]

Loss after 1197584 examples: 0.220


 48%|████▊     | 145/300 [1:21:27<1:30:01, 34.85s/it]

Loss after 1209993 examples: 0.172


 49%|████▊     | 146/300 [1:22:01<1:29:08, 34.73s/it]

Loss after 1222011 examples: 0.176


 49%|████▉     | 148/300 [1:23:11<1:28:00, 34.74s/it]

Loss after 1234420 examples: 0.161


 50%|████▉     | 149/300 [1:23:45<1:27:06, 34.61s/it]

Loss after 1246829 examples: 0.193


 50%|█████     | 151/300 [1:24:54<1:25:36, 34.48s/it]

Loss after 1258847 examples: 0.164


 51%|█████     | 152/300 [1:25:30<1:25:51, 34.81s/it]

Loss after 1271256 examples: 0.138


 51%|█████▏    | 154/300 [1:26:38<1:23:36, 34.36s/it]

Loss after 1283274 examples: 0.220


 52%|█████▏    | 155/300 [1:27:13<1:23:28, 34.54s/it]

Loss after 1295683 examples: 0.109


 52%|█████▏    | 157/300 [1:28:22<1:22:02, 34.42s/it]

Loss after 1307701 examples: 0.128


 53%|█████▎    | 158/300 [1:28:56<1:21:38, 34.49s/it]

Loss after 1320110 examples: 0.168


 53%|█████▎    | 160/300 [1:30:07<1:21:34, 34.96s/it]

Loss after 1332128 examples: 0.257


 54%|█████▎    | 161/300 [1:30:42<1:20:59, 34.96s/it]

Loss after 1344537 examples: 0.233


 54%|█████▍    | 163/300 [1:31:50<1:19:01, 34.61s/it]

Loss after 1356555 examples: 0.242


 55%|█████▍    | 164/300 [1:32:24<1:17:55, 34.38s/it]

Loss after 1368964 examples: 0.237


 55%|█████▌    | 166/300 [1:33:34<1:17:21, 34.64s/it]

Loss after 1380982 examples: 0.155


 56%|█████▌    | 167/300 [1:34:09<1:17:07, 34.79s/it]

Loss after 1393391 examples: 0.126


 56%|█████▋    | 169/300 [1:35:20<1:16:25, 35.00s/it]

Loss after 1405409 examples: 0.148


 57%|█████▋    | 170/300 [1:35:55<1:15:51, 35.01s/it]

Loss after 1417818 examples: 0.143


 57%|█████▋    | 171/300 [1:36:29<1:14:57, 34.86s/it]

Loss after 1429836 examples: 0.107


 58%|█████▊    | 173/300 [1:37:37<1:13:01, 34.50s/it]

Loss after 1442245 examples: 0.088


 58%|█████▊    | 174/300 [1:38:12<1:12:38, 34.59s/it]

Loss after 1454654 examples: 0.089


 59%|█████▊    | 176/300 [1:39:21<1:11:05, 34.40s/it]

Loss after 1466672 examples: 0.124


 59%|█████▉    | 177/300 [1:39:55<1:10:36, 34.44s/it]

Loss after 1479081 examples: 0.114


 60%|█████▉    | 179/300 [1:41:05<1:10:08, 34.78s/it]

Loss after 1491099 examples: 0.188


 60%|██████    | 180/300 [1:41:41<1:09:54, 34.96s/it]

Loss after 1503508 examples: 0.154


 61%|██████    | 182/300 [1:42:51<1:08:45, 34.96s/it]

Loss after 1515526 examples: 0.105


 61%|██████    | 183/300 [1:43:25<1:07:57, 34.85s/it]

Loss after 1527935 examples: 0.096


 62%|██████▏   | 185/300 [1:44:35<1:06:56, 34.92s/it]

Loss after 1539953 examples: 0.101


 62%|██████▏   | 186/300 [1:45:10<1:06:24, 34.95s/it]

Loss after 1552362 examples: 0.079


 63%|██████▎   | 188/300 [1:46:19<1:04:39, 34.64s/it]

Loss after 1564380 examples: 0.099


 63%|██████▎   | 189/300 [1:46:53<1:03:36, 34.38s/it]

Loss after 1576789 examples: 0.097


 64%|██████▎   | 191/300 [1:48:02<1:02:33, 34.44s/it]

Loss after 1588807 examples: 0.090


 64%|██████▍   | 192/300 [1:48:36<1:01:44, 34.30s/it]

Loss after 1601216 examples: 0.157


 65%|██████▍   | 194/300 [1:49:44<1:00:18, 34.14s/it]

Loss after 1613234 examples: 0.095


 65%|██████▌   | 195/300 [1:50:18<59:57, 34.26s/it]  

Loss after 1625643 examples: 0.116


 65%|██████▌   | 196/300 [1:50:53<59:27, 34.31s/it]

Loss after 1637661 examples: 0.153


 66%|██████▌   | 198/300 [1:52:01<58:08, 34.20s/it]

Loss after 1650070 examples: 0.077


 66%|██████▋   | 199/300 [1:52:35<57:42, 34.28s/it]

Loss after 1662479 examples: 0.079


 67%|██████▋   | 201/300 [1:53:45<57:09, 34.65s/it]

Loss after 1674497 examples: 0.125


 67%|██████▋   | 202/300 [1:54:19<56:17, 34.46s/it]

Loss after 1686906 examples: 0.105


 68%|██████▊   | 204/300 [1:55:29<55:30, 34.70s/it]

Loss after 1698924 examples: 0.180


 68%|██████▊   | 205/300 [1:56:03<54:32, 34.45s/it]

Loss after 1711333 examples: 0.105


 69%|██████▉   | 207/300 [1:57:11<53:00, 34.20s/it]

Loss after 1723351 examples: 0.147


 69%|██████▉   | 208/300 [1:57:46<52:47, 34.43s/it]

Loss after 1735760 examples: 0.107


 70%|███████   | 210/300 [1:58:55<51:27, 34.31s/it]

Loss after 1747778 examples: 0.110


 70%|███████   | 211/300 [1:59:28<50:43, 34.19s/it]

Loss after 1760187 examples: 0.078


 71%|███████   | 213/300 [2:00:39<50:20, 34.72s/it]

Loss after 1772205 examples: 0.058


 71%|███████▏  | 214/300 [2:01:14<49:57, 34.86s/it]

Loss after 1784614 examples: 0.072


 72%|███████▏  | 216/300 [2:02:23<48:44, 34.81s/it]

Loss after 1796632 examples: 0.087


 72%|███████▏  | 217/300 [2:02:58<48:10, 34.83s/it]

Loss after 1809041 examples: 0.076


 73%|███████▎  | 219/300 [2:04:07<46:43, 34.62s/it]

Loss after 1821059 examples: 0.091


 73%|███████▎  | 220/300 [2:04:41<46:07, 34.59s/it]

Loss after 1833468 examples: 0.095


 74%|███████▎  | 221/300 [2:05:17<45:47, 34.77s/it]

Loss after 1845486 examples: 0.109


 74%|███████▍  | 223/300 [2:06:25<44:12, 34.45s/it]

Loss after 1857895 examples: 0.153


 75%|███████▍  | 224/300 [2:07:00<43:50, 34.62s/it]

Loss after 1870304 examples: 0.141


 75%|███████▌  | 226/300 [2:08:09<42:28, 34.44s/it]

Loss after 1882322 examples: 0.120


 76%|███████▌  | 227/300 [2:08:44<42:04, 34.58s/it]

Loss after 1894731 examples: 0.117


 76%|███████▋  | 229/300 [2:09:54<41:17, 34.90s/it]

Loss after 1906749 examples: 0.134


 77%|███████▋  | 230/300 [2:10:29<40:49, 34.99s/it]

Loss after 1919158 examples: 0.073


 77%|███████▋  | 232/300 [2:11:38<39:18, 34.69s/it]

Loss after 1931176 examples: 0.114


 78%|███████▊  | 233/300 [2:12:14<39:01, 34.94s/it]

Loss after 1943585 examples: 0.088


 78%|███████▊  | 235/300 [2:13:24<37:55, 35.01s/it]

Loss after 1955603 examples: 0.076


 79%|███████▊  | 236/300 [2:13:58<37:06, 34.79s/it]

Loss after 1968012 examples: 0.081


 79%|███████▉  | 238/300 [2:15:07<35:50, 34.69s/it]

Loss after 1980030 examples: 0.108


 80%|███████▉  | 239/300 [2:15:41<35:01, 34.46s/it]

Loss after 1992439 examples: 0.135


 80%|████████  | 241/300 [2:16:51<34:13, 34.80s/it]

Loss after 2004457 examples: 0.104


 81%|████████  | 242/300 [2:17:26<33:37, 34.78s/it]

Loss after 2016866 examples: 0.155


 81%|████████▏ | 244/300 [2:18:35<32:14, 34.55s/it]

Loss after 2028884 examples: 0.054


 82%|████████▏ | 245/300 [2:19:09<31:39, 34.54s/it]

Loss after 2041293 examples: 0.112


 82%|████████▏ | 246/300 [2:19:44<31:01, 34.48s/it]

Loss after 2053311 examples: 0.064


 83%|████████▎ | 248/300 [2:20:53<29:52, 34.47s/it]

Loss after 2065720 examples: 0.035


 83%|████████▎ | 249/300 [2:21:27<29:16, 34.43s/it]

Loss after 2078129 examples: 0.044


 84%|████████▎ | 251/300 [2:22:36<28:11, 34.51s/it]

Loss after 2090147 examples: 0.034


 84%|████████▍ | 252/300 [2:23:10<27:31, 34.41s/it]

Loss after 2102556 examples: 0.023


 85%|████████▍ | 254/300 [2:24:20<26:34, 34.67s/it]

Loss after 2114574 examples: 0.034


 85%|████████▌ | 255/300 [2:24:55<26:06, 34.82s/it]

Loss after 2126983 examples: 0.046


 86%|████████▌ | 257/300 [2:26:07<25:23, 35.43s/it]

Loss after 2139001 examples: 0.061


 86%|████████▌ | 258/300 [2:26:42<24:42, 35.30s/it]

Loss after 2151410 examples: 0.030


 87%|████████▋ | 260/300 [2:27:51<23:07, 34.69s/it]

Loss after 2163428 examples: 0.045


 87%|████████▋ | 261/300 [2:28:25<22:31, 34.64s/it]

Loss after 2175837 examples: 0.048


 88%|████████▊ | 263/300 [2:29:34<21:18, 34.56s/it]

Loss after 2187855 examples: 0.068


 88%|████████▊ | 264/300 [2:30:09<20:49, 34.72s/it]

Loss after 2200264 examples: 0.040


 89%|████████▊ | 266/300 [2:31:17<19:27, 34.35s/it]

Loss after 2212282 examples: 0.052


 89%|████████▉ | 267/300 [2:31:52<18:52, 34.31s/it]

Loss after 2224691 examples: 0.041


 90%|████████▉ | 269/300 [2:33:01<17:50, 34.54s/it]

Loss after 2236709 examples: 0.062


 90%|█████████ | 270/300 [2:33:36<17:23, 34.79s/it]

Loss after 2249118 examples: 0.039


 90%|█████████ | 271/300 [2:34:11<16:44, 34.63s/it]

Loss after 2261136 examples: 0.041


 91%|█████████ | 273/300 [2:35:19<15:28, 34.39s/it]

Loss after 2273545 examples: 0.055


 91%|█████████▏| 274/300 [2:35:53<14:52, 34.31s/it]

Loss after 2285954 examples: 0.044


 92%|█████████▏| 276/300 [2:37:02<13:40, 34.17s/it]

Loss after 2297972 examples: 0.025


 92%|█████████▏| 277/300 [2:37:35<12:58, 33.84s/it]

Loss after 2310381 examples: 0.088


 93%|█████████▎| 279/300 [2:38:43<11:57, 34.17s/it]

Loss after 2322399 examples: 0.147


 93%|█████████▎| 280/300 [2:39:17<11:20, 34.02s/it]

Loss after 2334808 examples: 0.232


 94%|█████████▍| 282/300 [2:40:26<10:18, 34.38s/it]

Loss after 2346826 examples: 0.134


 94%|█████████▍| 283/300 [2:41:01<09:44, 34.40s/it]

Loss after 2359235 examples: 0.108


 95%|█████████▌| 285/300 [2:42:09<08:32, 34.19s/it]

Loss after 2371253 examples: 0.249


 95%|█████████▌| 286/300 [2:42:43<08:00, 34.34s/it]

Loss after 2383662 examples: 0.112


 96%|█████████▌| 288/300 [2:43:53<06:53, 34.44s/it]

Loss after 2395680 examples: 0.095


 96%|█████████▋| 289/300 [2:44:27<06:20, 34.56s/it]

Loss after 2408089 examples: 0.091


 97%|█████████▋| 291/300 [2:45:37<05:12, 34.69s/it]

Loss after 2420107 examples: 0.049


 97%|█████████▋| 292/300 [2:46:12<04:37, 34.71s/it]

Loss after 2432516 examples: 0.054


 98%|█████████▊| 294/300 [2:47:21<03:27, 34.58s/it]

Loss after 2444534 examples: 0.037


 98%|█████████▊| 295/300 [2:47:55<02:52, 34.54s/it]

Loss after 2456943 examples: 0.026


 99%|█████████▊| 296/300 [2:48:29<02:17, 34.38s/it]

Loss after 2468961 examples: 0.072


 99%|█████████▉| 298/300 [2:49:40<01:09, 34.78s/it]

Loss after 2481370 examples: 0.041


100%|█████████▉| 299/300 [2:50:14<00:34, 34.71s/it]

Loss after 2493779 examples: 0.036


100%|██████████| 300/300 [2:50:49<00:00, 34.16s/it]


Accuracy of the model on the 2079 test images: 65.56036556036555%


VBox(children=(Label(value=' 0.00MB of 0.00MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
accuracy,▁▂▂▃▃▃▅▄▄▅▅▆▆▇▇▇▇▇▇▇█▇██████████████████
epoch,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
loss,█▇▆▅▅▅▄▄▄▃▃▃▂▂▂▂▂▂▁▁▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
test_accuracy,▁

0,1
accuracy,0.99414
epoch,299.0
loss,0.03633
test_accuracy,0.6556


RuntimeError: step!=1 is currently not supported

In [33]:
img = Image.open('./test_samples/cork.png')
tf = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor()])
img_tensor = tf(img)
img_tensor = img_tensor.unsqueeze(0)
print(img_tensor.shape)

torch.Size([1, 3, 224, 224])


In [34]:
output = model(img_tensor)
_, predict = torch.max(output, 1)
score = torch.nn.functional.softmax(output, dim=1).detach().numpy()
img.show()


# print(
#     "This image most likely belongs to {} with a {:.2f} percent confidence."
#         .format(class_names[(np.argmax(score))], 100 * np.max(score))
# )

NameError: name 'model' is not defined

In [40]:
from IPython.display import IFrame

IFrame("https://wandb.ai/mindw96/TextureClassification/2pvdkac1", width="100%", height=720)