In [63]:
import matplotlib.pyplot as plt
from PIL import Image
import torch.nn as nn
import numpy as np
import os, json, glob

import torch
from torchvision import models, transforms
from torch.autograd import Variable
import torch.nn.functional as F
import sys
from lime import lime_image
from skimage.segmentation import mark_boundaries


In [78]:
# resize and take the center part of image to what our model expects
def get_input_transform():
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                    std=[0.229, 0.224, 0.225])       
    transf = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        normalize
    ])    

    return transf

def get_input_tensors(img):
    transf = get_input_transform()
    
    img2 = transf(img).unsqueeze(0)
    # unsqeeze converts single image to batch of 1
    
#     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#     model.to(device)
#     img = img.to(device)
    
    return img2

def get_pil_transform(): 
    transf = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.CenterCrop(224)
    ])    

    return transf

def get_preprocess_transform():
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                    std=[0.229, 0.224, 0.225])     
    transf = transforms.Compose([
        transforms.ToTensor(),
        normalize
    ])    

    return transf    

def batch_predict(images):
    model.eval()
    batch = torch.stack(tuple(preprocess_transform(i) for i in images), dim=0)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    batch = batch.to(device)
    
    logits = model(batch)
    probs = F.softmax(logits, dim=1)
    return probs.detach().cpu().numpy()

def get_image(path):
    with open(os.path.abspath(path), 'rb') as f:
        with Image.open(f) as img:
            return img.convert('RGB') 

In [81]:
PATH = r"I:\Research\House\Dataset\resnet84.pth"
feature_extract = True
num_classes = 4

batch_size=32
def set_parameter_requires_grad(model, feature_extracting):
    if feature_extracting:
        for param in model.parameters():
            param.requires_grad = False
            
model_ft = models.resnet101(pretrained=False)
set_parameter_requires_grad(model_ft, feature_extract)
num_ftrs = model_ft.fc.in_features
model_ft.fc = nn.Linear(num_ftrs, num_classes)
input_size = 224
model_ft.load_state_dict(torch.load(PATH))
 
# model_ft.cuda()

model = model_ft 
model.cuda()
model.eval()

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 

In [123]:

fi = files[0]

def lime_explain(fi):

    filename = os.path.splitext(os.path.basename(fi))[0]

    print(filename)
        
    img = get_image(fi)
    img_t = get_input_tensors(img)
    
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("device:", device)
    img_t = img_t.to(device)
#     batch = batch.to(device)
    
    logits = model(img_t)
    _, pred = torch.max(logits, 1)
#     pred = pred.item[0]


    
    print("logits:", logits)
#     print("pred:", pred.cpu().int())
    pred = pred.cpu().item()
    
    print("pred:", pred)
    



    pill_transf = get_pil_transform()
    preprocess_transform = get_preprocess_transform()




    explainer = lime_image.LimeImageExplainer()
    explanation = explainer.explain_instance(np.array(pill_transf(img)), 
                                             batch_predict,
                                             top_labels=5, 
                                             hide_color=0, 
                                             num_samples=1000)


    

    temp, mask = explanation.get_image_and_mask(explanation.top_labels[0], positive_only=True, num_features=5, hide_rest=False)
    img_boundry1 = mark_boundaries(temp/255.0, mask)
    plt.imsave(saved_path + "Pred=" + str(pred) + "_" +filename + '.png', img_boundry1)

    temp, mask = explanation.get_image_and_mask(explanation.top_labels[0], positive_only=False, num_features=10, hide_rest=False)
    img_boundry2 = mark_boundaries(temp/255.0, mask)
    plt.imsave(saved_path + filename + 'redgreen.png', img_boundry2)
    
    del img
    del img_t
    del logits


In [129]:
folder = r'X:\Shared drives\Group_research\Story_research\Datasets\LimeResult\LimeResult\test_img\2\*.jpg'
saved_path = r"X:\Shared drives\Group_research\Story_research\Datasets\LimeResult\LimeResult\test_img\2\results\\"
# saved_path = saved_path.replace(r"\\", '\\')
                                
files = glob.glob(folder)
print(len(files))


for idx, f in enumerate(files[:]):
    print("Processing: ", idx, "/", len(files))
    lime_explain(f)
    

65
Processing:  0 / 65
2129
device: cuda
logits: tensor([[-3.4576, -5.0386, -8.0023, 18.3002]], device='cuda:0',
       grad_fn=<AddmmBackward>)
pred: 3


HBox(children=(FloatProgress(value=0.0, max=1000.0), HTML(value='')))


Processing:  1 / 65
0_Kq4gGNsLehZ2eUkvXyWmdg_-76.583777_39.281887_0_40.52
device: cuda
logits: tensor([[ 2.4643,  0.0746, -5.1131,  3.0677]], device='cuda:0',
       grad_fn=<AddmmBackward>)
pred: 3


HBox(children=(FloatProgress(value=0.0, max=1000.0), HTML(value='')))


Processing:  2 / 65
0_lyMAbzmE1uLNehJ4H_zk6g_-76.617097_39.279067_0_207.24
device: cuda
logits: tensor([[ 2.3298, -3.7812,  2.1722, -0.8420]], device='cuda:0',
       grad_fn=<AddmmBackward>)
pred: 0


HBox(children=(FloatProgress(value=0.0, max=1000.0), HTML(value='')))


Processing:  3 / 65
0_NSy93SUY-AmGISYlBLHDfw_-76.685004_39.284277_0_193.05
device: cuda
logits: tensor([[-0.9942, -1.2368, -0.9195,  3.7783]], device='cuda:0',
       grad_fn=<AddmmBackward>)
pred: 3


HBox(children=(FloatProgress(value=0.0, max=1000.0), HTML(value='')))


Processing:  4 / 65
0_p-WJU0uvSmguyAocsE_J6A_-76.672609_39.313442_0_265.40
device: cuda
logits: tensor([[-1.5613, -4.3093,  1.5574,  4.7607]], device='cuda:0',
       grad_fn=<AddmmBackward>)
pred: 3


HBox(children=(FloatProgress(value=0.0, max=1000.0), HTML(value='')))


Processing:  5 / 65
0_Q9vWGo3uBLu3JxrzGG2ciQ_-76.586119_39.289968_0_105.21
device: cuda
logits: tensor([[ 1.8631, -6.3897,  1.2948,  3.7366]], device='cuda:0',
       grad_fn=<AddmmBackward>)
pred: 3


HBox(children=(FloatProgress(value=0.0, max=1000.0), HTML(value='')))


Processing:  6 / 65
0_qv8ISAA1Z7umXcUyqqUV6w_-76.582014_39.280716_0_355.34
device: cuda
logits: tensor([[ 3.3875, -0.4747, -4.1620,  1.6015]], device='cuda:0',
       grad_fn=<AddmmBackward>)
pred: 0


HBox(children=(FloatProgress(value=0.0, max=1000.0), HTML(value='')))


Processing:  7 / 65
0_RgCibHi8wxQUzS3JVgr64Q_-76.58327_39.281577_0_53.80
device: cuda
logits: tensor([[ 0.7898, -2.0212, -0.1632,  1.6263]], device='cuda:0',
       grad_fn=<AddmmBackward>)
pred: 3


HBox(children=(FloatProgress(value=0.0, max=1000.0), HTML(value='')))


Processing:  8 / 65
0_sE9A3IRNb9ANmP4F-5N0GA_-76.610742_39.308476_0_276.55
device: cuda
logits: tensor([[-1.8142, -8.1005,  7.3739,  2.8855]], device='cuda:0',
       grad_fn=<AddmmBackward>)
pred: 2


HBox(children=(FloatProgress(value=0.0, max=1000.0), HTML(value='')))


Processing:  9 / 65
0_ThntmhSsUs2TKjBdoJALxA_-76.615627_39.313273_0_24.56
device: cuda
logits: tensor([[ 4.7010, -6.7540,  4.1797, -2.4164]], device='cuda:0',
       grad_fn=<AddmmBackward>)
pred: 0


HBox(children=(FloatProgress(value=0.0, max=1000.0), HTML(value='')))


Processing:  10 / 65
0_uaEXPLJLokcBbJPyqFD-Uw_-76.573696_39.282289_0_347.96
device: cuda
logits: tensor([[-0.0998, -7.8294,  7.3049,  0.7929]], device='cuda:0',
       grad_fn=<AddmmBackward>)
pred: 2


HBox(children=(FloatProgress(value=0.0, max=1000.0), HTML(value='')))


Processing:  11 / 65
0_W1nVk_JGwQWWRXR71DMQRg_-76.650733_39.288001_0_358.58
device: cuda
logits: tensor([[ 2.6106, -9.0573,  4.1891,  2.5451]], device='cuda:0',
       grad_fn=<AddmmBackward>)
pred: 2


HBox(children=(FloatProgress(value=0.0, max=1000.0), HTML(value='')))


Processing:  12 / 65
0_wddv377Jtgt1hHddiTpaWQ_-76.594839_39.293529_0_294.41
device: cuda
logits: tensor([[ 1.0509, -2.5593, -3.3631,  5.3569]], device='cuda:0',
       grad_fn=<AddmmBackward>)
pred: 3


HBox(children=(FloatProgress(value=0.0, max=1000.0), HTML(value='')))


Processing:  13 / 65
0_y-u5esOU3IhQqwxZ3SCw1A_-76.627489_39.30168_0_263.61
device: cuda
logits: tensor([[ 1.4470, -4.7960, -0.1821,  3.9728]], device='cuda:0',
       grad_fn=<AddmmBackward>)
pred: 3


HBox(children=(FloatProgress(value=0.0, max=1000.0), HTML(value='')))


Processing:  14 / 65
0_zfh4sm9Q0zJNiC9RENl7hg_-76.666511_39.296756_0_211.08
device: cuda
logits: tensor([[-2.2470, -6.5719, -0.9249, 10.7687]], device='cuda:0',
       grad_fn=<AddmmBackward>)
pred: 3


HBox(children=(FloatProgress(value=0.0, max=1000.0), HTML(value='')))


Processing:  15 / 65
2060
device: cuda
logits: tensor([[ 4.3553, -1.5716, -4.1476,  1.6179]], device='cuda:0',
       grad_fn=<AddmmBackward>)
pred: 0


HBox(children=(FloatProgress(value=0.0, max=1000.0), HTML(value='')))


Processing:  16 / 65
2062
device: cuda
logits: tensor([[ 3.0848, -4.0444, -4.0511,  5.6239]], device='cuda:0',
       grad_fn=<AddmmBackward>)
pred: 3


HBox(children=(FloatProgress(value=0.0, max=1000.0), HTML(value='')))


Processing:  17 / 65
2063
device: cuda
logits: tensor([[ 2.8075, -3.9347, -1.9003,  3.4008]], device='cuda:0',
       grad_fn=<AddmmBackward>)
pred: 3


HBox(children=(FloatProgress(value=0.0, max=1000.0), HTML(value='')))


Processing:  18 / 65
2064
device: cuda
logits: tensor([[ 2.5846, -0.8129, -4.4500,  3.1818]], device='cuda:0',
       grad_fn=<AddmmBackward>)
pred: 3


HBox(children=(FloatProgress(value=0.0, max=1000.0), HTML(value='')))


Processing:  19 / 65
2065
device: cuda
logits: tensor([[-6.2447,  0.5584, -5.2293, 12.1296]], device='cuda:0',
       grad_fn=<AddmmBackward>)
pred: 3


HBox(children=(FloatProgress(value=0.0, max=1000.0), HTML(value='')))


Processing:  20 / 65
2067
device: cuda
logits: tensor([[-3.0926, -8.9256,  4.0261,  8.2745]], device='cuda:0',
       grad_fn=<AddmmBackward>)
pred: 3


HBox(children=(FloatProgress(value=0.0, max=1000.0), HTML(value='')))


Processing:  21 / 65
2069
device: cuda
logits: tensor([[ 1.1494, -4.4845, -0.6801,  4.4339]], device='cuda:0',
       grad_fn=<AddmmBackward>)
pred: 3


HBox(children=(FloatProgress(value=0.0, max=1000.0), HTML(value='')))


Processing:  22 / 65
2070
device: cuda
logits: tensor([[-1.0802,  0.3051, -4.6665,  6.2907]], device='cuda:0',
       grad_fn=<AddmmBackward>)
pred: 3


HBox(children=(FloatProgress(value=0.0, max=1000.0), HTML(value='')))


Processing:  23 / 65
2072
device: cuda
logits: tensor([[  5.3125,  -2.5750, -10.1644,   8.5120]], device='cuda:0',
       grad_fn=<AddmmBackward>)
pred: 3


HBox(children=(FloatProgress(value=0.0, max=1000.0), HTML(value='')))


Processing:  24 / 65
2073
device: cuda
logits: tensor([[ 3.9673, -6.4080, -1.4729,  4.4210]], device='cuda:0',
       grad_fn=<AddmmBackward>)
pred: 3


HBox(children=(FloatProgress(value=0.0, max=1000.0), HTML(value='')))


Processing:  25 / 65
2074
device: cuda
logits: tensor([[-0.8755, -5.5858,  5.3505,  1.2548]], device='cuda:0',
       grad_fn=<AddmmBackward>)
pred: 2


HBox(children=(FloatProgress(value=0.0, max=1000.0), HTML(value='')))


Processing:  26 / 65
2075
device: cuda
logits: tensor([[-4.2727, -0.7962, -4.2450, 10.1022]], device='cuda:0',
       grad_fn=<AddmmBackward>)
pred: 3


HBox(children=(FloatProgress(value=0.0, max=1000.0), HTML(value='')))


Processing:  27 / 65
2076
device: cuda
logits: tensor([[-0.0344, -1.6006, -8.2216, 10.9123]], device='cuda:0',
       grad_fn=<AddmmBackward>)
pred: 3


HBox(children=(FloatProgress(value=0.0, max=1000.0), HTML(value='')))


Processing:  28 / 65
2079
device: cuda
logits: tensor([[  2.4176,  -1.3311, -10.5334,  10.6271]], device='cuda:0',
       grad_fn=<AddmmBackward>)
pred: 3


HBox(children=(FloatProgress(value=0.0, max=1000.0), HTML(value='')))


Processing:  29 / 65
2080
device: cuda
logits: tensor([[-1.3608,  0.9129, -6.9199,  8.4482]], device='cuda:0',
       grad_fn=<AddmmBackward>)
pred: 3


HBox(children=(FloatProgress(value=0.0, max=1000.0), HTML(value='')))


Processing:  30 / 65
2081
device: cuda
logits: tensor([[  3.2729,  -3.3534, -10.2237,  11.6009]], device='cuda:0',
       grad_fn=<AddmmBackward>)
pred: 3


HBox(children=(FloatProgress(value=0.0, max=1000.0), HTML(value='')))


Processing:  31 / 65
2082
device: cuda
logits: tensor([[ 4.7301,  0.0882, -8.9639,  4.7222]], device='cuda:0',
       grad_fn=<AddmmBackward>)
pred: 0


HBox(children=(FloatProgress(value=0.0, max=1000.0), HTML(value='')))


Processing:  32 / 65
2083
device: cuda
logits: tensor([[-3.3020,  2.3516, -9.9506, 12.3292]], device='cuda:0',
       grad_fn=<AddmmBackward>)
pred: 3


HBox(children=(FloatProgress(value=0.0, max=1000.0), HTML(value='')))


Processing:  33 / 65
2085
device: cuda
logits: tensor([[-8.9501,  0.3619,  3.0506,  6.1795]], device='cuda:0',
       grad_fn=<AddmmBackward>)
pred: 3


HBox(children=(FloatProgress(value=0.0, max=1000.0), HTML(value='')))


Processing:  34 / 65
2086
device: cuda
logits: tensor([[  0.7593,  -2.0269, -16.3675,  19.7355]], device='cuda:0',
       grad_fn=<AddmmBackward>)
pred: 3


HBox(children=(FloatProgress(value=0.0, max=1000.0), HTML(value='')))


Processing:  35 / 65
2087
device: cuda
logits: tensor([[ -2.0662,   2.9207, -11.0164,  11.6626]], device='cuda:0',
       grad_fn=<AddmmBackward>)
pred: 3


HBox(children=(FloatProgress(value=0.0, max=1000.0), HTML(value='')))


Processing:  36 / 65
2088
device: cuda
logits: tensor([[ 0.5085, -0.2547, -7.5605,  8.1424]], device='cuda:0',
       grad_fn=<AddmmBackward>)
pred: 3


HBox(children=(FloatProgress(value=0.0, max=1000.0), HTML(value='')))


Processing:  37 / 65
2090
device: cuda
logits: tensor([[-2.8089, -4.9765, -5.8970, 15.0203]], device='cuda:0',
       grad_fn=<AddmmBackward>)
pred: 3


HBox(children=(FloatProgress(value=0.0, max=1000.0), HTML(value='')))


Processing:  38 / 65
2091
device: cuda
logits: tensor([[-4.3172, -6.3223,  1.0706, 10.3575]], device='cuda:0',
       grad_fn=<AddmmBackward>)
pred: 3


HBox(children=(FloatProgress(value=0.0, max=1000.0), HTML(value='')))


Processing:  39 / 65
2093
device: cuda
logits: tensor([[-3.8131, -1.0833, -7.7551, 14.1368]], device='cuda:0',
       grad_fn=<AddmmBackward>)
pred: 3


HBox(children=(FloatProgress(value=0.0, max=1000.0), HTML(value='')))


Processing:  40 / 65
2094
device: cuda
logits: tensor([[-5.8062, -1.2325, -6.6457, 15.2279]], device='cuda:0',
       grad_fn=<AddmmBackward>)
pred: 3


HBox(children=(FloatProgress(value=0.0, max=1000.0), HTML(value='')))


Processing:  41 / 65
2095
device: cuda
logits: tensor([[-2.8103,  4.2709, -9.9725,  9.6604]], device='cuda:0',
       grad_fn=<AddmmBackward>)
pred: 3


HBox(children=(FloatProgress(value=0.0, max=1000.0), HTML(value='')))


Processing:  42 / 65
2096
device: cuda
logits: tensor([[ 3.4266, -3.0932, -6.6485,  7.1443]], device='cuda:0',
       grad_fn=<AddmmBackward>)
pred: 3


HBox(children=(FloatProgress(value=0.0, max=1000.0), HTML(value='')))


Processing:  43 / 65
2097
device: cuda
logits: tensor([[-3.8150,  0.0571,  0.4319,  3.8023]], device='cuda:0',
       grad_fn=<AddmmBackward>)
pred: 3


HBox(children=(FloatProgress(value=0.0, max=1000.0), HTML(value='')))


Processing:  44 / 65
2099
device: cuda
logits: tensor([[ 0.1863, -2.2568, -5.1250,  8.3662]], device='cuda:0',
       grad_fn=<AddmmBackward>)
pred: 3


HBox(children=(FloatProgress(value=0.0, max=1000.0), HTML(value='')))


Processing:  45 / 65
2106
device: cuda
logits: tensor([[-0.3563,  4.0970, -6.2736,  3.2342]], device='cuda:0',
       grad_fn=<AddmmBackward>)
pred: 1


HBox(children=(FloatProgress(value=0.0, max=1000.0), HTML(value='')))


Processing:  46 / 65
2108
device: cuda
logits: tensor([[-0.0270,  0.5974, -5.7305,  5.6093]], device='cuda:0',
       grad_fn=<AddmmBackward>)
pred: 3


HBox(children=(FloatProgress(value=0.0, max=1000.0), HTML(value='')))


Processing:  47 / 65
2109
device: cuda
logits: tensor([[ 3.6774, -3.8894, -8.8612,  9.9591]], device='cuda:0',
       grad_fn=<AddmmBackward>)
pred: 3


HBox(children=(FloatProgress(value=0.0, max=1000.0), HTML(value='')))


Processing:  48 / 65
2110
device: cuda
logits: tensor([[ -1.0826,   5.5538, -13.7108,  10.5980]], device='cuda:0',
       grad_fn=<AddmmBackward>)
pred: 3


HBox(children=(FloatProgress(value=0.0, max=1000.0), HTML(value='')))


Processing:  49 / 65
2111
device: cuda
logits: tensor([[-0.5746, -0.3264, -8.5290, 10.6978]], device='cuda:0',
       grad_fn=<AddmmBackward>)
pred: 3


HBox(children=(FloatProgress(value=0.0, max=1000.0), HTML(value='')))


Processing:  50 / 65
2112
device: cuda
logits: tensor([[-3.3482e+00, -8.5269e-04, -1.5497e+00,  5.4535e+00]], device='cuda:0',
       grad_fn=<AddmmBackward>)
pred: 3


HBox(children=(FloatProgress(value=0.0, max=1000.0), HTML(value='')))


Processing:  51 / 65
2113
device: cuda
logits: tensor([[ 1.6547, -6.4936, -1.1027,  6.4333]], device='cuda:0',
       grad_fn=<AddmmBackward>)
pred: 3


HBox(children=(FloatProgress(value=0.0, max=1000.0), HTML(value='')))


Processing:  52 / 65
2115
device: cuda
logits: tensor([[-2.9067, -2.8384, -3.6090, 10.2479]], device='cuda:0',
       grad_fn=<AddmmBackward>)
pred: 3


HBox(children=(FloatProgress(value=0.0, max=1000.0), HTML(value='')))


Processing:  53 / 65
2116
device: cuda
logits: tensor([[ 5.4348, -2.5466, -8.9637,  6.9268]], device='cuda:0',
       grad_fn=<AddmmBackward>)
pred: 3


HBox(children=(FloatProgress(value=0.0, max=1000.0), HTML(value='')))


Processing:  54 / 65
2117
device: cuda
logits: tensor([[-3.7676, -2.1277, -6.0872, 13.3549]], device='cuda:0',
       grad_fn=<AddmmBackward>)
pred: 3


HBox(children=(FloatProgress(value=0.0, max=1000.0), HTML(value='')))


Processing:  55 / 65
2118
device: cuda
logits: tensor([[-3.3777, -4.0794, -6.1125, 15.1906]], device='cuda:0',
       grad_fn=<AddmmBackward>)
pred: 3


HBox(children=(FloatProgress(value=0.0, max=1000.0), HTML(value='')))


Processing:  56 / 65
2119
device: cuda
logits: tensor([[-4.3768, -2.5707, -2.9418, 11.1448]], device='cuda:0',
       grad_fn=<AddmmBackward>)
pred: 3


HBox(children=(FloatProgress(value=0.0, max=1000.0), HTML(value='')))


Processing:  57 / 65
2120
device: cuda
logits: tensor([[-1.8050, -4.8262, -6.7822, 14.8277]], device='cuda:0',
       grad_fn=<AddmmBackward>)
pred: 3


HBox(children=(FloatProgress(value=0.0, max=1000.0), HTML(value='')))


Processing:  58 / 65
2121
device: cuda
logits: tensor([[-3.5630, -3.9199, -7.2535, 16.4780]], device='cuda:0',
       grad_fn=<AddmmBackward>)
pred: 3


HBox(children=(FloatProgress(value=0.0, max=1000.0), HTML(value='')))


Processing:  59 / 65
2122
device: cuda
logits: tensor([[ -1.9757,  -2.3549, -10.7477,  16.8796]], device='cuda:0',
       grad_fn=<AddmmBackward>)
pred: 3


HBox(children=(FloatProgress(value=0.0, max=1000.0), HTML(value='')))


Processing:  60 / 65
2123
device: cuda
logits: tensor([[-1.7186, -3.1052, -8.8133, 15.3906]], device='cuda:0',
       grad_fn=<AddmmBackward>)
pred: 3


HBox(children=(FloatProgress(value=0.0, max=1000.0), HTML(value='')))


Processing:  61 / 65
2124
device: cuda
logits: tensor([[ 8.4540, -7.0361, -8.3713,  8.1233]], device='cuda:0',
       grad_fn=<AddmmBackward>)
pred: 0


HBox(children=(FloatProgress(value=0.0, max=1000.0), HTML(value='')))


Processing:  62 / 65
2126
device: cuda
logits: tensor([[-0.4384, -5.9283, -7.5944, 15.5726]], device='cuda:0',
       grad_fn=<AddmmBackward>)
pred: 3


HBox(children=(FloatProgress(value=0.0, max=1000.0), HTML(value='')))


Processing:  63 / 65
2127
device: cuda
logits: tensor([[-6.4906, -2.5022, -5.2663, 15.8067]], device='cuda:0',
       grad_fn=<AddmmBackward>)
pred: 3


HBox(children=(FloatProgress(value=0.0, max=1000.0), HTML(value='')))


Processing:  64 / 65
2128
device: cuda
logits: tensor([[-1.5300, -5.2207, -5.1962, 12.8762]], device='cuda:0',
       grad_fn=<AddmmBackward>)
pred: 3


HBox(children=(FloatProgress(value=0.0, max=1000.0), HTML(value='')))


