<a href="https://colab.research.google.com/github/gekoramy/uni.deep-learning/blob/finetune-like-you-pretrain/notebook.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
%%shell
tee requirements.txt << END
ftfy
jaxtyping
jupyter
matplotlib
pydantic
regex
torch
torchinfo
torchvision
tqdm
ultralytics
END

pip install -q -r requirements.txt
pip install -q git+https://github.com/openai/CLIP.git

ftfy
jaxtyping
jupyter
matplotlib
pydantic
regex
torch
torchinfo
torchvision
tqdm
ultralytics
  Preparing metadata (setup.py) ... [?25l[?25hdone




In [2]:
import clip
import json
import os
import pickle
import re
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import numpy as np
import PIL
import itertools as it
import math

from datetime import datetime
from jaxtyping import Float, UInt, Int
from pydantic.dataclasses import dataclass
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from torchvision import transforms
from torchvision.utils import draw_bounding_boxes
from torchvision.io import read_image
from torchinfo import summary
from typing import Literal, Callable, Mapping, TypeVar
from tqdm import tqdm
from timeit import default_timer as timer
from torch.utils.tensorboard import SummaryWriter

In [3]:
device: Literal['cpu', 'cuda'] = 'cuda' if torch.cuda.is_available() else 'cpu'
torch.set_default_device(device)
device

'cpu'

#### Utils

In [4]:
def print_train_time(start: float, end: float, device: torch.device = None):
    """Prints difference between start and end time.

    Args:
        start (float): Start time of computation (preferred in timeit format).
        end (float): End time of computation.
        device ([type], optional): Device that compute is running on. Defaults to None.

    Returns:
        float: time between start and end in seconds (higher is longer).
    """
    total_time = end - start
    print(f"Train time on {device}: {total_time:.3f} seconds")
    return total_time

#### Dataset and type declaration

In [5]:
%%shell
if ! [ -d dataset ]; then
  mkdir dataset &&
  gdown 1P8a1g76lDJ8cMIXjNDdboaRR5-HsVmUb &&
  tar -xf refcocog.tar.gz -C dataset &&
  rm refcocog.tar.gz
fi



In [6]:
root = os.path.join("dataset", "refcocog", "")
data_instances = os.path.join(root, "annotations", "instances.json")
data_refs = os.path.join(root, "annotations", "refs(umd).p")
data_images = os.path.join(root, "images", "")

In [7]:
I = TypeVar("I")
P = TypeVar("P")
B = TypeVar("B")
T = TypeVar("T")

Img = UInt[torch.Tensor, "C W H"]
BBox = UInt[torch.Tensor, "4"]
Split = Literal["train", "test", "val"]


@dataclass
class Info:
    description: str  # This is stable 1.0 version of the 2014 MS COCO dataset.
    url: str  # http://mscoco.org/
    version: str  # 1.0
    year: int  # 2014
    contributor: str  # Microsoft COCO group
    date_created: datetime  # 2015-01-27 09:11:52.357475


@dataclass
class Image:
    license: int  # each image has an associated licence id
    file_name: str  # file name of the image
    coco_url: str  # example http://mscoco.org/images/131074
    height: int
    width: int
    flickr_url: str  # example http://farm9.staticflickr.com/8308/7908210548_33e
    id: int  # id of the imag
    date_captured: datetime  # example '2013-11-21 01:03:06'


@dataclass
class License:
    url: str  # example http://creativecommons.org/licenses/by-nc-sa/2.0/
    id: int  # id of the licence
    name: str  # example 'Attribution-NonCommercial-ShareAlike License


@dataclass
class Annotation:
    # segmentation: list[list[float]]  # description of the mask; example [[44.17, 217.83, 36.21, 219.37, 33.64, 214.49, 31.08, 204.74, 36.47, 202.68, 44.17, 203.2]]
    area: float  # number of pixel of the described object
    iscrowd: Literal[
        1, 0
    ]  # Crowd annotations (iscrowd=1) are used to label large groups of objects (e.g. a crowd of people)
    image_id: int  # id of the target image
    bbox: tuple[
        float, float, float, float
    ]  # bounding box coordinates [xmin, ymin, width, height]
    category_id: int
    id: int  # annotation id


@dataclass
class Category:
    supercategory: str  # example 'vehicle'
    id: int  # category id
    name: str  # example 'airplane'


@dataclass
class Instances:
    info: Info
    images: list[Image]
    licenses: list[License]
    annotations: list[Annotation]
    categories: list[Category]


@dataclass
class Sentence:
    tokens: list[str]  # tokenized version of referring expression
    raw: str  # unprocessed referring expression
    sent: str  # referring expression with mild processing, lower case, spell correction, etc.
    sent_id: int  # unique referring expression id


@dataclass
class Ref:
    image_id: int  # unique image id
    split: Split
    sentences: list[Sentence]
    file_name: str  # file name of image relative to img_root
    category_id: int  # object category label
    ann_id: int  # id of object annotation in instance.json
    sent_ids: list[int]  # same ids as nested sentences[...][sent_id]
    ref_id: int  # unique id for refering expression

In [8]:
def fix_ref(x: Ref) -> Ref:
    x.file_name = fix_filename(x.file_name)
    return x


def fix_filename(x: str) -> str:
    """
    :param x: COCO_..._[image_id]_[annotation_id].jpg
    :return:  COCO_..._[image_id].jpg

    >>> fix_filename('COCO_..._[image_id]_0000000001.jpg')
    'COCO_..._[image_id].jpg'

    """
    return re.sub("_\d+\.jpg$", ".jpg", x)

In [9]:
with open(data_refs, "rb") as f:
    raw = pickle.load(f)

refs: list[Ref] = [fix_ref(Ref(**ref)) for ref in raw]

In [10]:
with open(data_instances, "r") as f:
    raw = json.load(f)

instances: Instances = Instances(**raw)

id2annotation: Mapping[int, Annotation] = {x.id: x for x in instances.annotations}

In [11]:
class CocoDataset(Dataset[tuple[PIL.Image, list[str], Float[torch.Tensor, "4"]]]):
    def __init__(
        self,
        split: Split,
        limit: int = -1,
    ):
        self.__init__
        self.items: list[tuple[str, list[str], Float[torch.Tensor, "4"]]] = [
            (i, [s.sent for s in ss], xywh)
            for ref in refs
            if ref.split == split
            for i in [os.path.join(data_images, ref.file_name)]
            for ss in [ref.sentences]
            for xywh in [torch.tensor(id2annotation[ref.ann_id].bbox, dtype=torch.float)]
        ]
        self.len: int = len(self.items) if limit < 0 else min(limit, len(self.items))

    def __len__(self) -> int:
        return self.len

    def __getitem__(
        self, index: int
    ) -> tuple[PIL.Image, list[str], Float[torch.Tensor, "4"]]:
        i, ps, xywh = self.items[index]
        xyxy: Float[torch.Tensor, "4"] = torchvision.ops.box_convert(xywh, in_fmt="xywh", out_fmt="xyxy")
        with PIL.Image.open(i) as img:
            img.load()
            return img, ps, xyxy

In [12]:
class Coco4CLIPDataset(Dataset[tuple[list[PIL.Image], list[str]]]):
    def __init__(
        self,
        split: Split,
        limit: int = -1,
    ):
        self.__init__
        self.items: list[tuple[str, list[str], Float[torch.Tensor, "4"]]] = [
            (i, [s.sent for s in ss], xywh)
            for ref in refs
            if ref.split == split
            for i in [os.path.join(data_images, ref.file_name)]
            for ss in [ref.sentences]
            for xywh in [torch.tensor(id2annotation[ref.ann_id].bbox, dtype=torch.float)]
        ]
        self.len: int = len(self.items) if limit < 0 else min(limit, len(self.items))

    def __len__(self) -> int:
        return self.len

    def __getitem__(self, index: int) -> tuple[list[PIL.Image], list[str]]:
        i, ps, xywh = self.items[index]
        xyxy: Float[torch.Tensor, "4"] = torchvision.ops.box_convert(xywh, in_fmt="xywh", out_fmt="xyxy")
        with PIL.Image.open(i) as img:
            img.load()
            return [img.crop(xyxy.tolist())], ps

In [13]:
def unzip(batch: list[tuple[T, ...]]) -> tuple[list[T], ...]:
    return tuple(zip(*batch))

In [14]:
batch_size: int = 3
limit: int = 5 * batch_size

In [15]:
dl: DataLoader[tuple[list[PIL.Image], list[list[str]], list[Float[torch.Tensor, "4"]]]] = DataLoader(
    dataset=CocoDataset(split="test", limit=limit),
    batch_size=batch_size,
    collate_fn=unzip,
)

In [16]:
dl4clip: DataLoader[tuple[list[PIL.Image], list[str]]] = DataLoader(
    dataset=Coco4CLIPDataset(split="test", limit=limit),
    batch_size=batch_size,
    collate_fn=unzip,
    shuffle=True,
)

In [17]:
imgs: tuple[PIL.Image, ...]
promptss: tuple[list[str], ...]
true_xyxy: tuple[Float[torch.Tensor, "4"], ...]

for imgs, promptss, true_xyxy in dl:
    print(imgs)
    print(promptss)
    print(true_xyxy)
    print("-" * 50)

(<PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=640x376 at 0x78130528A7D0>, <PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=640x431 at 0x7813050F53C0>, <PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=640x426 at 0x7813050F5420>)
(['the man in yellow coat', 'skiier in red pants'], ['there is red colored truck in between the other trucks', 'a shiny red vintage pickup truck'], ['a apple desktop computer', 'the white imac computer that is also turned on'])
(tensor([374.3100,  65.0600, 510.3500, 267.0000]), tensor([ 93.9500,  83.2900, 598.5600, 373.8600]), tensor([338.8000,  82.1900, 486.1400, 239.5600]))
--------------------------------------------------
(<PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=640x480 at 0x7813050F5720>, <PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=640x275 at 0x7813050F5780>, <PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=500x375 at 0x7813050F5900>)
(['a girl wearing glasses and a pink shirt', 'an asian girl with a pin

In [18]:
cropss: tuple[list[PIL.Image], ...]
promptss: tuple[list[str], ...]

for cropss, promptss in dl4clip:
    print(cropss)
    print(promptss)
    print("-" * 50)

([<PIL.Image.Image image mode=RGB size=185x213 at 0x78130528AAD0>], [<PIL.Image.Image image mode=RGB size=100x156 at 0x78130528B070>], [<PIL.Image.Image image mode=RGB size=582x257 at 0x781305289930>])
(['a brown bear near a soda bottle', 'a without hairy brown color teddy bear'], ['a lady in blue t - shirt and white shorts sitting on a park bench', 'a couple of friends are sitting on a bench and hanging out'], ['a table with pizza , drinks , and seasonings on it', 'a table of food , with plates , a pizza , pitchers , and glasses'])
--------------------------------------------------
([<PIL.Image.Image image mode=RGB size=505x291 at 0x7813050F5D20>], [<PIL.Image.Image image mode=RGB size=62x178 at 0x7813050F5900>], [<PIL.Image.Image image mode=RGB size=83x169 at 0x7813050F59F0>])
(['there is red colored truck in between the other trucks', 'a shiny red vintage pickup truck'], ['a man standing next to a young girl on a grassy hillside', 'a man in a black jacket'], ['woman in coveralls', '

# Fine tune like you pretrain
In the following we try to fine tune CLIP image and text encoders using contrastive learning as proposed by the original paper.

In [133]:
class FLYP_CLIP(nn.Module):
  def __init__(self, device=device):  #TODO: aggiungere device=device anche nelle architetture dello standard fine tuning

    super().__init__()

    model, preprocess = clip.load('RN50')

    # freeze all pretrained layers by setting requires_grad=False
    for param in model.parameters():
      param.requires_grad = False

    self.clip_visual_encoder = model.encode_image
    self.clip_text_encoder = model.encode_text
    self.clip_visual_preprocess = preprocess
    self.clip_text_preprocess = clip.tokenize

    self.visual_encoder_linearHead = nn.Linear(1024, 1024)
    self.text_encoder_linearHead = nn.Linear(1024, 1024)

    # the temperature parameter is added as suggested by the original paper in order to prevent training instability
    self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))

  # preprocess input prompts as required by the visual encoder
  def visual_preprocess(self, _imgs):
    prep_images = torch.stack([
        self.clip_visual_preprocess(i)
        for i in _imgs
    ]).to(device)

    return prep_images

  # preprocess text prompts as required by the text encoder
  def text_preprocess(self, _txts):
    prep_texts = self.clip_text_preprocess(_txts)

    return prep_texts

  # visual encoder
  def visual_encoder(self, image):
    with torch.no_grad():
      clipFeatures = self.clip_visual_encoder(image)

    x = F.relu(clipFeatures)
    x = self.visual_encoder_linearHead(x)

    return x

  # text encoder
  def text_encoder(self, text):
    with torch.no_grad():
      clipFeatures = self.clip_text_encoder(text)

    x = F.relu(clipFeatures)
    x = self.text_encoder_linearHead(x)

    return x

  def forward(self, image, text):
    with torch.no_grad():
      image_pre = self.visual_preprocess(image)
      text_pre = self.text_preprocess(text)

    image_features = self.visual_encoder(image_pre)
    text_features = self.text_encoder(text_pre)

    return image_features, text_features, self.logit_scale.exp()

In [134]:
def get_optimizer(model, _lr, _wd, _momentum):
  optimizer = torch.optim.SGD(  params = model.parameters(),
                                lr = _lr,
                                weight_decay = _wd,
                                momentum = _momentum)
  return optimizer

In [135]:
class ClipLoss(nn.Module):

    def __init__(self):
        super().__init__()

    def get_ground_truth(self, num_logits):
        labels = torch.arange(num_logits)
        return labels

    def get_logits(self, image_features, text_features, logit_scale):
        logits_per_image = logit_scale * image_features @ text_features.T
        logits_per_text = logit_scale * text_features @ image_features.T

        return logits_per_image, logits_per_text

    def forward(self, image_features, text_features, logit_scale):

        # compute logits per image and logits per text
        logits_per_image, logits_per_text = self.get_logits(image_features, text_features, logit_scale)

        # get ground truth labels for the computation of the cross entropy loss
        labels = self.get_ground_truth(logits_per_image.shape[0])

        total_loss = (
            F.cross_entropy(logits_per_image, labels) +
            F.cross_entropy(logits_per_text, labels)
        ) / 2

        return total_loss

In [136]:
def training_step(
    model: torch.nn.Module,                   # neural network to be trained
    data_loader: torch.utils.data.DataLoader, # data loader to be iterated
    loss_fn: torch.nn.Module,                 # loss function
    optimizer: torch.optim.Optimizer,         # optimizer
    device: torch.device = device             #target device
):
  train_loss = 0.0

  model.to(device)
  model.train()

  for batch_idx, (cropss, promptss) in tqdm(enumerate(data_loader)):

    # for this implementation we consider only one prompt for each crop
    model_input_crops = [c[0] for c in cropss]
    model_input_prompts = [p[0] for p in promptss]

    # send data to target device
    ####cropss = cropss.to(device)
    ####promptss = promptss.to(device)

    # forward computation
    model_out = model(model_input_crops, model_input_prompts)
    image_features = model_out[0]
    text_features = model_out[1]
    logit_scale = model_out[2]

    # calculate loss
    loss = loss_fn(image_features, text_features, logit_scale)
    train_loss += loss

    # optimizer zero grad
    optimizer.zero_grad()

    # loss backward
    loss.backward()

    # optimizer step
    optimizer.step()

    # Note: we clamp to 4.6052 = ln(100), as in the original paper.
    with torch.no_grad():
        model.logit_scale.clamp_(0, math.log(100))


  # Calculate loss per epoch and print out what's happening
  train_loss /= len(data_loader)
  print(f"Train loss: {train_loss:.5f}\n")
  return train_loss

In [137]:
def test_step(
    model: torch.nn.Module,                   # neural network to be evaluated
    data_loader: torch.utils.data.DataLoader, # data loader to be iterated
    loss_fn: torch.nn.Module,                 # loss function
    device: torch.device = device             #target device
):
  test_loss = 0.0

  model.to(device)
  model.eval()

  with torch.inference_mode():
    #for batch_idx, cropss, promptss in tqdm(enumerate(data_loader)):
    for batch_idx, (cropss, promptss) in tqdm(enumerate(data_loader)):

      # for this implementation we consider only one prompt for each crop
      model_input_crops = [c[0] for c in cropss]
      model_input_prompts = [p[0] for p in promptss]

      # send data to target device
      ####cropss = cropss.to(device)
      ####promptss = promptss.to(device)

      # forward computation
      model_out = model(model_input_crops, model_input_prompts)
      image_features = model_out[0]
      text_features = model_out[1]
      logit_scale = model_out[2]

      # calculate loss
      loss = loss_fn(image_features, text_features, logit_scale)
      test_loss += loss

    test_loss /= len(data_loader)
    print(f"Test loss: {test_loss:.5f}\n")
    return test_loss

In [138]:
# instantiate the network and move it to the chosen device
net = FLYP_CLIP().to(device)

In [139]:
# tensorboard logging utilities
def log_values(writer, step, loss, prefix):
  writer.add_scalar(f"{prefix}/loss", loss, step)

In [140]:
# setting a manual seed allow us to provide reprudicible results in this notebook
torch.manual_seed(42)

# create a logger for the experiment
writer = SummaryWriter(log_dir="runs/exp1")

BATCH_SIZE = 3
LIMIT = 5 * BATCH_SIZE
NUM_WORKERS = 1

# get dataset instance
train_dataset = Coco4CLIPDataset(split="train", limit=LIMIT)
test_dataset = Coco4CLIPDataset(split="test", limit=LIMIT)
val_dataset = Coco4CLIPDataset(split="val", limit=LIMIT)
print(f"LEN_TRAIN_DATASET: {len(train_dataset)}, LEN_TEST_DATASET: {len(test_dataset)}, LEN_VALIDATION_DATASET: {len(val_dataset)}")

# get dataloaders
print(f"Creating DataLoader's with batch size {BATCH_SIZE} and {NUM_WORKERS} workers.")
train_loader: DataLoader[tuple[list[PIL.Image], list[str]]] = DataLoader(
    dataset=train_dataset,
    batch_size=BATCH_SIZE,
    collate_fn=unzip,
    shuffle=True,
)
test_loader: DataLoader[tuple[list[PIL.Image], list[str]]] = DataLoader(
    dataset=test_dataset,
    batch_size=BATCH_SIZE,
    collate_fn=unzip,
    shuffle=False,
)
val_loader: DataLoader[tuple[list[PIL.Image], list[str]]] = DataLoader(
    dataset=val_dataset,
    batch_size=BATCH_SIZE,
    collate_fn=unzip,
    shuffle=False,
)
print(f"LEN_TRAIN_DATALOADER: {len(train_loader)}, LEN_TEST_DATALOADER: {len(val_loader)}, LEN_VALIDATION_DATALOADER: {len(test_loader)}")

# instantiate the optimizer
learning_rate = 0.01
weight_decay = 0.000001
momentum = 0.9
optimizer = get_optimizer(net, learning_rate, weight_decay, momentum)

# define the cost function
cost_function = ClipLoss().to(device)

print('Before training:')
train_loss = test_step(model = net,
        data_loader = train_loader,
        loss_fn = cost_function)
val_loss = test_step(model = net,
        data_loader = val_loader,
        loss_fn = cost_function)
test_loss = test_step(model = net,
        data_loader = test_loader,
        loss_fn = cost_function)

# log to TensorBoard
log_values(writer, -1, train_loss, "train")
log_values(writer, -1, val_loss, "validation")
log_values(writer, -1, test_loss, "test")

print('\tTraining loss {:.5f}'.format(train_loss))
print('\tValidation loss {:.5f}'.format(val_loss))
print('\tTest loss {:.5f}'.format(test_loss))
print('-----------------------------------------------------')

# measure time
train_time_start = timer()

EPOCHS = 3
for epoch in tqdm(range(EPOCHS)):
    train_loss = training_step(
        model = net,
        data_loader = train_loader,
        loss_fn = cost_function,
        optimizer = optimizer
    )

    val_loss = test_step(
        model = net,
        data_loader = val_loader,
        loss_fn = cost_function
    )

    # logs to TensorBoard
    log_values(writer, epoch, train_loss, "train")
    log_values(writer, epoch, val_loss, "validation")

    print('Epoch: {:d}'.format(epoch+1))
    print('\tTraining loss {:.5f}'.format(train_loss))
    print('\tValidation loss {:.5f}'.format(val_loss))
    print('-----------------------------------------------------')

train_time_end = timer()
total_train_time_model_1 = print_train_time(start=train_time_start,
                                            end=train_time_end,
                                            device=device)
# compute final evaluation results
print('After training:')
train_loss = test_step(model = net,
        data_loader = train_loader,
        loss_fn = cost_function)
val_loss = test_step(model = net,
        data_loader = val_loader,
        loss_fn = cost_function)
test_loss = test_step(model = net,
        data_loader = test_loader,
        loss_fn = cost_function)

# log to TensorBoard
log_values(writer, EPOCHS, train_loss, "train")
log_values(writer, EPOCHS, val_loss, "validation")
log_values(writer, EPOCHS, test_loss, "test")

print('\tTraining loss {:.5f}'.format(train_loss))
print('\tValidation loss {:.5f}'.format(val_loss))
print('\tTest loss {:.5f}'.format(test_loss))
print('-----------------------------------------------------')

# closes the logger
writer.close()

LEN_TRAIN_DATASET: 15, LEN_TEST_DATASET: 15, LEN_VALIDATION_DATASET: 15
Creating DataLoader's with batch size 3 and 1 workers.
LEN_TRAIN_DATALOADER: 5, LEN_TEST_DATALOADER: 5, LEN_VALIDATION_DATALOADER: 5
Before training:


5it [00:05,  1.12s/it]


Test loss: 1.43407



5it [00:05,  1.02s/it]


Test loss: 1.29528



5it [00:06,  1.31s/it]


Test loss: 1.27050

	Training loss 1.43407
	Validation loss 1.29528
	Test loss 1.27050
-----------------------------------------------------


  0%|          | 0/3 [00:00<?, ?it/s]
0it [00:00, ?it/s][A
1it [00:01,  1.05s/it][A
2it [00:02,  1.09s/it][A
3it [00:03,  1.31s/it][A
4it [00:04,  1.21s/it][A
5it [00:05,  1.17s/it]


Train loss: 1.80077




0it [00:00, ?it/s][A
1it [00:01,  1.02s/it][A
2it [00:02,  1.16s/it][A
3it [00:04,  1.42s/it][A
4it [00:05,  1.48s/it][A
5it [00:06,  1.40s/it]
 33%|███▎      | 1/3 [00:12<00:25, 12.91s/it]

Test loss: 1.79844

Epoch: 1
	Training loss 1.80077
	Validation loss 1.79844
-----------------------------------------------------



0it [00:00, ?it/s][A
1it [00:01,  1.72s/it][A
2it [00:03,  1.59s/it][A
3it [00:04,  1.34s/it][A
4it [00:05,  1.23s/it][A
5it [00:06,  1.28s/it]


Train loss: 2.31575




0it [00:00, ?it/s][A
1it [00:01,  1.03s/it][A
2it [00:02,  1.03s/it][A
3it [00:03,  1.32s/it][A
4it [00:05,  1.45s/it][A
5it [00:06,  1.32s/it]
 67%|██████▋   | 2/3 [00:25<00:12, 12.98s/it]

Test loss: 2.28227

Epoch: 2
	Training loss 2.31575
	Validation loss 2.28227
-----------------------------------------------------



0it [00:00, ?it/s][A
1it [00:01,  1.08s/it][A
2it [00:02,  1.07s/it][A
3it [00:03,  1.06s/it][A
4it [00:04,  1.06s/it][A
5it [00:05,  1.06s/it]


Train loss: 3.21034




0it [00:00, ?it/s][A
1it [00:01,  1.04s/it][A
2it [00:02,  1.04s/it][A
3it [00:03,  1.04s/it][A
4it [00:04,  1.07s/it][A
5it [00:05,  1.19s/it]
100%|██████████| 3/3 [00:37<00:00, 12.41s/it]


Test loss: 3.36217

Epoch: 3
	Training loss 3.21034
	Validation loss 3.36217
-----------------------------------------------------
Train time on cpu: 37.231 seconds
After training:


5it [00:05,  1.17s/it]


Test loss: 4.83593



5it [00:05,  1.02s/it]


Test loss: 3.36217



5it [00:06,  1.32s/it]

Test loss: 2.81674

	Training loss 4.83593
	Validation loss 3.36217
	Test loss 2.81674
-----------------------------------------------------



