In [1]:
from __future__ import print_function, division

import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
import numpy as np
import torchvision

import torch.nn.functional as F
from torchvision import datasets, models, transforms
import matplotlib.pyplot as plt
import time
import os
import copy
import pandas as pd
import collections
import sklearn
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder

import time
plt.ion()   # interactive mode

from database_functions.databasereader import DatabaseReader

# Custom Functions
from shoe_dataset import ShoeDataSet
from shoe_dataset import train_transformations, valid_transformations
from feature_maps_extractor import ExtractFeatureMaps
from class_activation_maps import ScoreCam

torch.manual_seed(1994)

ModuleNotFoundError: No module named 'database_functions'

In [None]:
# Let pandas show up to 500 rows before truncating a displayed frame.
pd.set_option("display.max_rows", 500)

In [None]:
class custom_model:
    """Fine-tuning wrapper around a torchvision ResNet backbone.

    The wrapper selects a device, builds the backbone, freezes parameters
    according to ``unfreeze_layers`` / ``freeze_factor``, and logs training
    to TensorBoard through a ``SummaryWriter`` created on ``output_dir``.
    """

    def __init__(self, model_type, unfreeze_layers, output_dir, freeze_factor=1, pretrained=True):
        """Build the backbone, move it to the device and freeze parameters.

        Args:
            model_type: ``"resnet101"`` or ``"resnet50"``.
            unfreeze_layers: names of top-level children of the backbone
                (e.g. ``"fc"``, ``"layer4"``) that stay (partly) trainable.
            output_dir: directory handed to ``SummaryWriter``.
            freeze_factor: see ``_apply_freezing``; ``1`` leaves every
                parameter of the unfrozen layers trainable.
            pretrained: forwarded to the torchvision model constructor.

        Raises:
            ValueError: if ``model_type`` is not one of the supported names.
        """
        if torch.cuda.is_available():
            self.device = torch.device("cuda:0")
            print("Model set up on the GPU")
        else:
            self.device = torch.device("cpu")
            print("Running on the CPU")

        # Fail loudly on an unknown backbone instead of printing and then
        # crashing later on a missing ``self.model`` attribute.
        backbones = {"resnet101": models.resnet101, "resnet50": models.resnet50}
        if model_type not in backbones:
            raise ValueError(
                f"model type unrecognised: {model_type!r} (expected one of {sorted(backbones)})"
            )
        self.model = backbones[model_type](pretrained=pretrained)
        print(f"Model backbone set to: {model_type}")

        # Push the model to device
        self.model = self.model.to(self.device)

        self.unfreeze_layers = unfreeze_layers
        self.freeze_factor = freeze_factor
        self._apply_freezing()

        # Defines the output dir for the TensorBoard logs
        self.output_dir = output_dir
        self.writer = SummaryWriter(self.output_dir)

    def _apply_freezing(self):
        """Set ``requires_grad = False`` on the parameters that must not train.

        Children not named in ``self.unfreeze_layers`` are frozen entirely.
        For children that are named, parameters from index
        ``int(freeze_factor * n_params)`` onward are frozen, so
        ``freeze_factor == 1`` freezes nothing inside them.
        """
        # NOTE(review): for partially unfrozen layers this keeps the *leading*
        # parameters trainable and freezes the trailing ones - confirm that
        # this is the intended direction.
        for name, child in self.model.named_children():
            params = list(child.parameters())
            if name not in self.unfreeze_layers:
                to_freeze = params
            else:
                to_freeze = params[int(self.freeze_factor * len(params)):]
            for param in to_freeze:
                param.requires_grad = False

    @staticmethod
    def _labels_to_tensor(labels):
        """Return ``labels`` as a ``long`` tensor (accepts tensors or array-likes)."""
        if torch.is_tensor(labels):
            return labels.long()
        return torch.from_numpy(np.asarray(labels).astype("long"))

    @staticmethod
    def _batch_accuracy(outputs, labels):
        """Fraction of rows of ``outputs`` whose argmax equals the class index in ``labels``."""
        if labels.numel() == 0:
            return 0.0
        return (outputs.argmax(dim=1) == labels).float().mean().item()

    @staticmethod
    def _mean(values):
        """Arithmetic mean of a list of floats; ``nan`` for an empty list."""
        return float(sum(values) / len(values)) if values else float("nan")

    def _frozen_reference_param(self):
        """Return the first parameter of ``model.layer4`` if it is frozen, else ``None``."""
        layer4 = getattr(self.model, "layer4", None)
        if layer4 is None:
            return None
        first = next(iter(layer4.parameters()), None)
        if first is None or first.requires_grad:
            return None
        return first

    def _is_plateau_scheduler(self):
        """True when the scheduler needs a metric passed to ``step``."""
        return isinstance(self.lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau)

    def train(self, loss_func, optimizer, lr_scheduler, learning_rate, epochs, trainloader, valloader, eval_period=100):
        """Fine-tune the model and restore the best-validation weights at the end.

        Args:
            loss_func: loss class/factory, called with no arguments.
            optimizer: optimizer class/factory, called as
                ``optimizer(trainable_params, lr=learning_rate)``.
            lr_scheduler: scheduler class/factory, called as
                ``lr_scheduler(optimizer_instance)``.
            learning_rate: learning rate handed to the optimizer.
            epochs: number of epochs to run.
            trainloader: iterable of dict batches with ``"image"`` and ``"label"``.
            valloader: iterable of dict batches with ``"image"``, ``"label"``
                and ``"file_name"``.
            eval_period: validation runs on epochs where
                ``epoch % eval_period == 0``.

        Returns:
            dict with the per-epoch lists ``train_loss``, ``train_acc`` and the
            per-evaluation lists ``val_loss``, ``val_acc`` (also stored on
            ``self.history``).

        Raises:
            ValueError: if ``eval_period`` is smaller than 1.
        """
        if eval_period < 1:
            raise ValueError("eval_period must be >= 1")

        self.trainloader = trainloader
        self.valloader = valloader
        self.learning_rate = learning_rate
        self.eval_period = eval_period
        self.epochs = epochs

        # Keep a detached *copy* of one frozen layer4 parameter so that the
        # sanity check below compares against the starting values rather than
        # against the live parameter itself.
        reference_param = self._frozen_reference_param()
        self._original_weight = None if reference_param is None else reference_param.detach().clone()

        print(f"Eval Period set to: {self.eval_period}")

        self.criterion = loss_func()
        self.optimizer = optimizer(
            (p for p in self.model.parameters() if p.requires_grad),
            lr=self.learning_rate,
        )
        self.lr_scheduler = lr_scheduler(self.optimizer)

        history = {"train_loss": [], "train_acc": [], "val_loss": [], "val_acc": []}
        best_val_acc = 0.0
        # Start from the current weights so there is always something to restore.
        self.best_model_wts = copy.deepcopy(self.model.state_dict())

        for epoch in range(self.epochs):

            # For the first epochs after the initial one, confirm that the
            # frozen reference parameter really did not change.
            if 0 < epoch < 3 and self._original_weight is not None:
                print(f"Did the Weights change in epoch: {epoch}?")
                if torch.equal(reference_param.detach(), self._original_weight):
                    print("No")
                else:
                    print("Yes, abort!!!")

            train_loss, train_acc = self._train_one_epoch(epoch)
            self.writer.add_scalar("Training Loss", train_loss, epoch)
            history["train_loss"].append(train_loss)
            history["train_acc"].append(train_acc)

            # Schedulers that do not take a metric advance once per epoch.
            if not self._is_plateau_scheduler():
                self.lr_scheduler.step()

            if epoch % self.eval_period != 0:
                continue

            print(f"Starting evaluating at epoch: {epoch}")
            val_loss, val_acc, last_inputs, last_labels = self._evaluate(epoch)
            history["val_loss"].append(val_loss)
            history["val_acc"].append(val_acc)

            self.writer.add_scalar("Validation Loss", val_loss, global_step=epoch)

            # Visualise the predictions on the last validation batch
            if last_inputs is not None:
                self.writer.add_figure(
                    "Predictions vs. GT",
                    self.plot_classes_preds(last_inputs, last_labels),
                    global_step=epoch,
                )

            # Keep the weights with the best mean validation accuracy.
            if val_acc > best_val_acc:
                print(f"Saving model at epoch: {epoch}")
                best_val_acc = val_acc
                self.best_model_wts = copy.deepcopy(self.model.state_dict())

            # A plateau scheduler in its default "min" mode expects a quantity
            # that should decrease, so it is stepped with the validation loss
            # and only when a fresh value exists. Its patience is therefore
            # counted in evaluations, not in epochs.
            if self._is_plateau_scheduler():
                self.lr_scheduler.step(val_loss)

        # load_state_dict works in place; its return value is not a model.
        self.model.load_state_dict(self.best_model_wts)
        self.history = history
        return history

    def _train_one_epoch(self, epoch):
        """Run one pass over ``self.trainloader``; return (mean loss, mean accuracy)."""
        # NOTE(review): train() also switches BatchNorm layers inside frozen
        # blocks to training mode, so their running statistics presumably keep
        # updating - confirm whether that is wanted.
        self.model.train()

        batch_losses = []
        batch_accs = []
        total_steps = len(self.trainloader)

        for step, batch in enumerate(self.trainloader):
            inputs = batch["image"].to(self.device)
            labels = self._labels_to_tensor(batch["label"]).to(self.device)

            # Zero out the accumulated gradients
            self.optimizer.zero_grad()

            outputs = self.model(inputs)
            loss = self.criterion(outputs, labels)
            loss.backward()
            self.optimizer.step()

            # item() yields plain Python floats detached from the graph/device
            batch_losses.append(loss.item())
            batch_accs.append(self._batch_accuracy(outputs, labels))

            # print out stats every 200 iterations
            if step % 200 == 0:
                current_lr = self.optimizer.param_groups[0]['lr']
                print(f"Current lr: {current_lr}")
                print(f"Step: {step}/{total_steps}; Epoch: {epoch+1}/{self.epochs}; Train Batch Loss: {loss.item()}")

        return self._mean(batch_losses), self._mean(batch_accs)

    def _evaluate(self, epoch):
        """Run one pass over ``self.valloader`` without gradients.

        Returns:
            (mean loss, mean accuracy, inputs of last batch, labels of last
            batch); the last two are ``None`` when the loader is empty.
        """
        self.model.eval()

        batch_losses = []
        batch_accs = []
        last_inputs = None
        last_labels = None

        with torch.no_grad():
            for step, batch in enumerate(self.valloader):
                self.val_file_name = batch["file_name"]

                # Inputs and labels must live on the same device as the model.
                inputs = batch["image"].to(self.device)
                labels = self._labels_to_tensor(batch["label"]).to(self.device)

                outputs = self.model(inputs)
                batch_losses.append(self.criterion(outputs, labels).item())
                batch_accs.append(self._batch_accuracy(outputs, labels))

                # Class activation maps are logged for the first validation
                # batch of each evaluated epoch only.
                if step == 0:
                    self._log_class_activation_maps(batch, epoch)

                last_inputs, last_labels = inputs, labels

        return self._mean(batch_losses), self._mean(batch_accs), last_inputs, last_labels

    def _log_class_activation_maps(self, batch, epoch):
        """Generate ScoreCam heatmaps for layer1..layer4 and log them as images."""
        print("Preparing CAM")
        per_layer_heatmaps = []

        for layer_no in range(1, 5):
            print(f"Layer{layer_no}")
            score_cam = ScoreCam(self.model, f"layer{layer_no}")

            heatmaps = []
            for image, name in zip(batch["image"], batch["file_name"]):
                # generate_cam receives a single-image batch
                _, heatmap_image = score_cam.generate_cam(
                    input_image=torch.unsqueeze(image, 0), filename=name
                )
                heatmaps.append(np.array(heatmap_image))
            per_layer_heatmaps.append(heatmaps)

        # assumes every heatmap converts to an (H, W, C) array - TODO confirm.
        # Under that assumption hstack joins the four layers along axis 1 and
        # the two moveaxis calls bring the channel axis to position 1.
        stacked = np.hstack(per_layer_heatmaps)
        stacked = np.moveaxis(stacked, -1, 1)
        stacked = np.moveaxis(stacked, -1, -2)
        # keep only the first three channels
        stacked = stacked[:, :3, :, :]

        self.writer.add_images(
            "Class Activation Maps, Layers 1-4",
            torch.from_numpy(stacked),
            global_step=epoch,
        )

    def matplotlib_imshow(self, img, one_channel=False):
        """Show one image tensor on the current matplotlib axes.

        The tensor is un-normalised with ``img / 2 + 0.5``; with
        ``one_channel`` the channels are averaged and drawn in greys.
        """
        if one_channel:
            img = img.mean(dim=0)
        img = img / 2 + 0.5     # unnormalize
        npimg = img.cpu().numpy()
        if one_channel:
            plt.imshow(npimg, cmap="Greys")
        else:
            plt.imshow(np.transpose(npimg, (1, 2, 0)))

    def images_to_probs(self, images):
        '''
        Generates predictions and corresponding probabilities from a trained
        network and a batch of images.

        Returns a 1-D numpy array of predicted class indices and a list with
        the softmax probability of each predicted class (any batch size,
        including 1).
        '''
        output = self.model(images.to(self.device))
        probabilities = F.softmax(output, dim=1)
        top_probs, preds_tensor = torch.max(probabilities, 1)
        preds = preds_tensor.detach().cpu().numpy()
        return preds, top_probs.detach().cpu().tolist()

    def plot_classes_preds(self, images, labels, class_names=None):
        '''
        Generates matplotlib Figure using a trained network, along with images
        and labels from a batch, that shows the network's top prediction along
        with its probability, alongside the actual label, coloring this
        information based on whether the prediction was correct or not.
        Uses the "images_to_probs" function.

        At most the first four images of the batch are drawn. ``class_names``
        maps class index to display name; when omitted the module-level
        ``classes`` mapping is used.
        '''
        if class_names is None:
            class_names = classes
        labels = labels.detach().cpu()
        with torch.no_grad():
            preds, probs = self.images_to_probs(images)

        n_shown = min(4, len(preds))
        # plot the images in the batch, along with predicted and true labels
        fig = plt.figure(figsize=(48, 15))
        for idx in range(n_shown):
            print(f"pred idx: {preds[idx]}")
            ax = fig.add_subplot(1, 4, idx + 1, xticks=[], yticks=[])
            self.matplotlib_imshow(images[idx], one_channel=False)
            true_idx = labels[idx].item()
            ax.set_title(
                "{0}, {1:.1f}%\n(label: {2})".format(
                    class_names[preds[idx]],
                    probs[idx] * 100.0,
                    class_names[true_idx],
                ),
                color=("green" if preds[idx] == true_idx else "red"),
                fontsize=40,
            )
        return fig



In [4]:
"""

Import all data from the database

"""


# Open a reader on the "crepcheque" database (DatabaseReader is a project helper).
db_reader = DatabaseReader("crepcheque")

# One row per image: the shoe id and raw brand text from raw_creps joined to
# the image id, type and file path from images.
# NOTE(review): the WHERE clause filters on i.image_downloaded, so shoes
# without a matching image row are dropped even though the join is a LEFT JOIN.
get_images_query ="""

                    SELECT
                        r.crep_id,
                        r.raw_brand_text,
                        i.image_id,
                        i.image_type,
                        i.image_file_path
                    FROM raw_creps r
                    LEFT JOIN
                        images i
                        ON r.crep_id = i.crep_id
                    WHERE 
                        r.images IS NOT NULL
                        AND r.images_processed = true
                        AND i.image_downloaded = true
                    """

# return_as_df=True: the result is used as a pandas DataFrame by the cells below.
database_df = db_reader.send_query(query=get_images_query, return_as_df=True)

In [5]:
database_df.shape

(328852, 5)

In [6]:
# Number of image rows per brand, as a two-column frame (brand, count).
brand_counts = (
    database_df["raw_brand_text"]
    .value_counts()
    .rename_axis("brand")
    .reset_index(name="count")
)
# Keep only brands with at least 10 image rows.
brands_to_keep = brand_counts.loc[brand_counts["count"] >= 10, "brand"].tolist()

In [7]:
# Rows whose brand passed the minimum-count filter, re-indexed from 0.
filtered_db_df = database_df.loc[
    database_df["raw_brand_text"].isin(brands_to_keep)
].reset_index(drop=True)

In [8]:
def sample_db_data(brand, db_df, n=3000, random_state=None):
    """Return at most ``n`` randomly chosen rows of ``db_df`` for one brand.

    Args:
        brand: value to match against the ``raw_brand_text`` column.
        db_df: DataFrame with a ``raw_brand_text`` column.
        n: maximum number of rows to return; when the brand has fewer rows,
            all of them are returned.
        random_state: optional seed forwarded to ``DataFrame.sample`` for a
            reproducible draw.

    Returns:
        A DataFrame of ``min(n, rows for brand)`` rows in random order with a
        fresh 0-based index.
    """
    brand_rows = db_df[db_df.raw_brand_text == brand]
    # Cap at the available rows; slicing with .loc[0:n] would be
    # label-inclusive and hand back n + 1 rows.
    n_rows = min(n, len(brand_rows))
    return brand_rows.sample(n=n_rows, random_state=random_state).reset_index(drop=True)

In [9]:
# NOTE(review): scratch value - the `for brand in brands_to_keep` loop in the
# next cell rebinds `brand`, so this assignment has no lasting effect.
brand = 'Gucci'

In [10]:
# One down-sampled frame per retained brand (n=10 requested per brand).
sampled_df_list = [
    sample_db_data(brand=brand_name, db_df=filtered_db_df, n=10)
    for brand_name in brands_to_keep
]

In [11]:
# Stack the per-brand samples row-wise under a fresh 0-based index.
downsampled_filtered_df = pd.concat(sampled_df_list, ignore_index=True)

In [12]:
# Fit an integer encoding over the brand names present in the sampled frame.
# NOTE(review): the mapping printed further down lists 'ASICS' and 'Asics',
# and 'OFF-WHITE' and 'Off-White', as separate classes - consider normalising
# raw_brand_text before encoding.
label_encoder = LabelEncoder()
label_encoder.fit(downsampled_filtered_df.raw_brand_text.tolist())

LabelEncoder()

In [13]:
# Add the integer class id of every row's brand as a new column.
brand_names = downsampled_filtered_df["raw_brand_text"].tolist()
downsampled_filtered_df["encoded_label"] = label_encoder.transform(brand_names)

In [14]:
downsampled_filtered_df.head()

Unnamed: 0,crep_id,raw_brand_text,image_id,image_type,image_file_path,encoded_label
0,17233,Nike,175179,Additional,/home/max/hdd-data1/images/original/17233-1751...,38
1,14465,Nike,147997,Additional,/home/max/hdd-data1/images/original/14465-1479...,38
2,5555,Nike,56005,Additional,/home/max/hdd-data1/images/original/5555-56005...,38
3,9046,Nike,92204,Additional,/home/max/hdd-data1/images/original/9046-92204...,38
4,9087,Nike,92659,Additional,/home/max/hdd-data1/images/original/9087-92659...,38


In [15]:
# Map every encoded class id to its brand name (row order, duplicates collapse).
labels_encoding_dict = dict(
    zip(
        downsampled_filtered_df["encoded_label"].tolist(),
        downsampled_filtered_df["raw_brand_text"].tolist(),
    )
)
labels_encoding_dict


{38: 'Nike',
 26: 'Jordan',
 63: 'adidas',
 60: 'Vans',
 37: 'New Balance',
 14: 'Converse',
 0: 'ASICS',
 48: 'Reebok',
 47: 'Puma',
 5: 'Balenciaga',
 59: 'Under Armour',
 53: 'Saucony',
 57: 'Timberland',
 17: 'Dior',
 4: 'BAPE',
 16: 'Diadora',
 22: 'Gucci',
 39: 'OFF-WHITE',
 12: 'Clarks',
 33: 'Louis Vuitton',
 62: 'Yeezy',
 32: 'Li-Ning',
 27: 'K-Swiss',
 18: 'Dr. Martens',
 19: 'Ewing Athletics',
 20: 'FEAR OF GOD',
 11: 'Chanel',
 15: 'DC Shoes',
 61: 'Versace',
 21: 'Fila',
 23: 'Hoka One One',
 54: 'Sonra',
 28: 'KangaROOS',
 7: 'Birkenstock',
 55: 'Suicoke',
 58: 'Tommy Hilfiger',
 3: 'Asics',
 40: 'Off-White',
 13: 'Common Projects',
 29: 'Karhu',
 45: 'Prada',
 41: 'Onitsuka Tiger',
 1: 'Alexander McQueen',
 24: 'Hummel',
 56: 'Supra',
 6: 'Big Baller Brand',
 34: 'Mephisto',
 44: 'Pizza Hut',
 25: 'Ice Cream',
 2: 'Anta',
 30: 'Lakai',
 43: 'Palace',
 42: 'Osiris',
 49: 'Revenge X Storm',
 50: 'Rhude',
 8: 'Brandblack',
 10: 'Burberry',
 52: 'Salomon',
 51: 'Saint Lauren

In [16]:
# Same id -> brand mapping, ordered by ascending class id.
labels_encoding_dict_sorted = collections.OrderedDict(
    (class_id, labels_encoding_dict[class_id])
    for class_id in sorted(labels_encoding_dict)
)
labels_encoding_dict_sorted

OrderedDict([(0, 'ASICS'),
             (1, 'Alexander McQueen'),
             (2, 'Anta'),
             (3, 'Asics'),
             (4, 'BAPE'),
             (5, 'Balenciaga'),
             (6, 'Big Baller Brand'),
             (7, 'Birkenstock'),
             (8, 'Brandblack'),
             (9, 'Brooks'),
             (10, 'Burberry'),
             (11, 'Chanel'),
             (12, 'Clarks'),
             (13, 'Common Projects'),
             (14, 'Converse'),
             (15, 'DC Shoes'),
             (16, 'Diadora'),
             (17, 'Dior'),
             (18, 'Dr. Martens'),
             (19, 'Ewing Athletics'),
             (20, 'FEAR OF GOD'),
             (21, 'Fila'),
             (22, 'Gucci'),
             (23, 'Hoka One One'),
             (24, 'Hummel'),
             (25, 'Ice Cream'),
             (26, 'Jordan'),
             (27, 'K-Swiss'),
             (28, 'KangaROOS'),
             (29, 'Karhu'),
             (30, 'Lakai'),
             (31, 'Le Coq Sportif'),
     

In [17]:
# `classes` is read as a module-level global by custom_model.plot_classes_preds
# to turn class ids into brand names for the figure titles.
classes = labels_encoding_dict_sorted

In [18]:
# classes[0]

In [19]:
"""

Shuffle the dataframe and split into train and validation

""";

# shuffled_df = sklearn.utils.shuffle(database_df)

# train_df = shuffled_df.iloc[:int(database_df.shape[0]*0.8), :]
# val_df = shuffled_df.iloc[int(database_df.shape[0]*0.8):, :]

# Image paths are the inputs, encoded brand ids are the targets.
X = downsampled_filtered_df["image_file_path"]
y = downsampled_filtered_df["encoded_label"]

# 80/20 split, stratified on the label, with a fixed seed.
X_train, X_val, y_train, y_val = train_test_split(
    X,
    y,
    test_size=0.2,
    stratify=y,
    random_state=1994,
)

# Put each split's paths and labels side by side in a two-column frame.
train_df = pd.concat([X_train, y_train], axis=1)
val_df = pd.concat([X_val, y_val], axis=1)



In [20]:
"""

Define the image transformations for the train and validation sets,
then build the datasets and data loaders.

"""
# Training pipeline: random perspective / horizontal flip augmentation, then
# resize to 300x300, tensor conversion and normalisation with mean 0.5, std 0.5.
transforms_train = transforms.Compose(
    [
        transforms.RandomPerspective(p=0.4),  # randomly change img perspective
        transforms.RandomHorizontalFlip(p=0.2),
        transforms.Resize((300, 300)),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ]
)
# Validation pipeline: the same steps without the random augmentation.
transforms_valid = transforms.Compose(
    [
        transforms.Resize((300, 300)),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ]
)

train_set = ShoeDataSet(train_df, transform=transforms_train)
val_set = ShoeDataSet(val_df, transform=transforms_valid)

# Both loaders shuffle, use batches of 4 and drop an incomplete final batch.
loader_options = dict(shuffle=True, batch_size=4, drop_last=True)
train_loader = DataLoader(train_set, **loader_options)
val_loader = DataLoader(val_set, **loader_options)


In [21]:
"""

Add layers which are to be unfrozen for finetuning purposes

Options: fc (bare minimum), layer1-4 (conv2d layer bottlenecks)

"""

# Build a ResNet-101 wrapper in which only the "fc" child stays trainable.
# NOTE(review): this rebinds the name `custom_model` from the class to an
# instance, so the class is no longer reachable under that name afterwards;
# a distinct instance name would avoid that.
custom_model = custom_model("resnet101",["fc"], output_dir="../runs/")

Model set up on the GPU
Model backbone set to: resnet101


In [22]:
"""

Redefine the FC layer of your model so that it matches 
the number of classes present in the dataset

Ensure that it is placed on the same device as the model (the GPU when available)

"""

# Swap the classification head for one with an output per class in `y`.
# The input width is read from the existing head instead of being hard-coded,
# and the new layer is placed on the device the wrapper selected, so the cell
# also runs on CPU-only machines.
num_classes = len(set(y))
head_in_features = custom_model.model.fc.in_features
custom_model.model.fc = torch.nn.Linear(head_in_features, num_classes).to(custom_model.device)

# Define the output directory for the logs
# (the SummaryWriter was already created in __init__ from the constructor's
# output_dir; this assignment only updates the attribute).
custom_model.output_dir = "../runs/"



In [23]:
custom_model.output_dir

'../runs/'

In [24]:
"""

Define the loss function (criterion)
Optimizer - Adam
LR scheduler - ReduceLROnPlateau

"""

# These are passed as classes, not instances: custom_model.train() itself calls
# loss_func(), optimizer(params, lr=learning_rate) and lr_scheduler(optimizer).
criterion = torch.nn.CrossEntropyLoss
optimizer = torch.optim.Adam
learning_rate = 0.001
# NOTE(review): rebinds `lr_scheduler`, which the import cell bound to the
# torch.optim.lr_scheduler module.
lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau

In [25]:
"""

Initialise the training of the model


"""
# small eval period for debugging purposes

# Start fine-tuning for 100 epochs; validation runs on epochs where
# epoch % eval_period == 0.
custom_model.train(
    loss_func=criterion,
    optimizer=optimizer,
    lr_scheduler=lr_scheduler,
    learning_rate=learning_rate,
    epochs=100,
    trainloader=train_loader,
    valloader=val_loader,
    eval_period=10,
)


Eval Period set to: 10
Current lr: 0.001
Step: 0/142; Epoch: 1/100; Train Batch Loss: 4.212957859039307


  "Palette images with Transparency expressed in bytes should be "


Starting evaluating at epoch: 0
Preparing CAM
Layer1
Layer2
Layer3
Layer4
pred idx: 63
pred idx: 46
pred idx: 46
pred idx: 46
Did the Weights change in epoch: 1?
No
Current lr: 0.001
Step: 0/142; Epoch: 2/100; Train Batch Loss: 3.9643640518188477
Did the Weights change in epoch: 2?
No
Current lr: 0.001
Step: 0/142; Epoch: 3/100; Train Batch Loss: 4.322152137756348
Current lr: 0.001
Step: 0/142; Epoch: 4/100; Train Batch Loss: 4.403522491455078
Current lr: 0.001
Step: 0/142; Epoch: 5/100; Train Batch Loss: 3.9957828521728516
Current lr: 0.001
Step: 0/142; Epoch: 6/100; Train Batch Loss: 4.210690498352051
Current lr: 0.001
Step: 0/142; Epoch: 7/100; Train Batch Loss: 4.2820515632629395
Current lr: 0.001
Step: 0/142; Epoch: 8/100; Train Batch Loss: 4.463751792907715
Current lr: 0.001
Step: 0/142; Epoch: 9/100; Train Batch Loss: 4.358029365539551
Current lr: 0.001
Step: 0/142; Epoch: 10/100; Train Batch Loss: 4.182094573974609
Current lr: 0.001
Step: 0/142; Epoch: 11/100; Train Batch Loss:

KeyboardInterrupt: 

In [None]:
torchvision.transforms.ToTensor()

In [8]:
# Scratch pipeline that only converts an image to a tensor.
# NOTE(review): the TypeError recorded at the end of the notebook ("img should
# be PIL Image. Got <class 'torch.Tensor'>") presumably came from a run where
# the commented-out RandomRotation was still active after ToTensor - confirm.
trmf = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
#     torchvision.transforms.RandomRotation(90)
])

In [4]:
from PIL import Image


In [5]:
# NOTE(review): absolute, machine-specific path - this cell cannot run on
# another machine; consider a path relative to a configurable data directory.
pil_image = Image.open("/Users/michalbarrington/Downloads/snake.jpg")

In [6]:
img_tens = trmf(pil_image)

TypeError: img should be PIL Image. Got <class 'torch.Tensor'>