In [1]:
! pip install pytorch-lightning==2.1.2
! pip3 install numpy==1.23.5
! pip install protobuf==3.20.*
! pip install onnx
! pip install wandb

[0m

In [2]:
import os
from datetime import datetime

import wandb
import torch
import pytorch_lightning as pl

from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.callbacks import Callback
from pytorch_lightning.callbacks import ModelCheckpoint
import torchmetrics

from model import Net, Model, DataModule

In [3]:
torch.set_float32_matmul_precision('high')

# Data

In [4]:
#TODO: Hydra
#Train/Val
train_dataset_name = 'COCO25K'
train_dataset_path = '/gsn/datasets/COCO25K/train'
val_dataset_path = '/gsn/datasets/COCO25K/val'

#Test
test_dataset_name = 'BSD100'
test_dataset_path_hr = '/gsn/datasets/test/BSD100/image_SRF_4'

# Train

In [5]:
#TODO: Hydra
IMAGE_SIZE = 100
SCALING_FACTOR = 4

NUM_EPOCHS = 10
LEARNING_RATE = 1e-3
BATCH_SIZE = 200
NUM_WORKERS = 4

In [6]:
! wandb login --relogin 88a995e9bf3cca3f7e93c8c48c2bb5835f08305a

[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


In [7]:
os.environ['WANDB_NOTEBOOK_NAME'] = '/gsn/main'

In [8]:
MODEL_CKPT_PATH = 'model/'
MODEL_CKPT = 'model-{epoch:02d}-{val_loss:.2f}'

checkpoint_callback = ModelCheckpoint(
    monitor='train_loss', #change to val_loss
    dirpath=MODEL_CKPT_PATH,
    filename=MODEL_CKPT,
    save_top_k=3,
    mode='min')

In [None]:
# Inicjalizacja wandb
now = datetime.now()
run_name = "train_" + now.strftime('%Y-%m-%d_%H:%M:%S')
tags = ["train"]
run = wandb.init(
    project='gsn_super_resolution',
    name=run_name,
    tags=tags,
    config={
        "dataset": train_dataset_name,
        "epochs": NUM_EPOCHS,
        "scaling_factor": SCALING_FACTOR,
        "image_size": IMAGE_SIZE,
    },
)
wandb_logger = WandbLogger(project='gsn_super_resolution')

# Inicjalizacja modelu
dm =  DataModule(batch_size=BATCH_SIZE, num_workers=NUM_WORKERS, 
                 image_size=IMAGE_SIZE, scaling_factor=SCALING_FACTOR,
                 train_dataset_path=train_dataset_path,
                 val_dataset_path=val_dataset_path,
                 test_dataset_path=test_dataset_path_hr)

dense_net = Net()
model = Model(model=dense_net, scaling_factor=SCALING_FACTOR, learning_rate=LEARNING_RATE)

# Inicjalizacja trenera
trainer = pl.Trainer(max_epochs=NUM_EPOCHS, logger=wandb_logger, callbacks=[checkpoint_callback], log_every_n_steps=2, fast_dev_run=False)

# Trenowanie modelu
trainer.fit(model,datamodule=dm)

[34m[1mwandb[0m: Currently logged in as: [33mpiotrczernecki[0m ([33mgsn-sr[0m). Use [1m`wandb login --relogin`[0m to force relogin


GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
/opt/conda/lib/python3.10/site-packages/pytorch_lightning/loggers/wandb.py:389: There is a wandb run already in progress and newly created instances of `WandbLogger` will reuse this run. If this is not desired, call `wandb.finish()` before instantiating `WandbLogger`.
/opt/conda/lib/python3.10/site-packages/pytorch_lightning/callbacks/model_checkpoint.py:639: Checkpoint directory model/ exists and is not empty.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name       | Type                             | Params
----------------------------------------------------------------
0 | model      | Net                              | 5.5 M 
1 | train_ssim | StructuralSimilarityIndexMeasure | 0     
2 | val_ssim   | StructuralSimilarityIndexMeasure | 0     
3 | test_ssim  | StructuralSimilarityIndexMeasure | 0     
4 | train_psnr | PeakSignalNoi

Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

In [None]:
artifact = wandb.Artifact('model', type='model')
artifact.add_dir(MODEL_CKPT_PATH)

run.log_artifact(artifact)
run.join()
wandb.finish()

# Test

In [None]:
# Inicjalizacja wandb
now = datetime.now()
run_name = "test_" + now.strftime('%Y-%m-%d_%H:%M:%S')
tags = ["test"]
run = wandb.init(
    project='gsn_super_resolution',
    name=run_name,
    tags=tags,
    config={
        "dataset": test_dataset_name,
        "scaling_factor": SCALING_FACTOR,
        "image_size": IMAGE_SIZE,
    },
)
wandb_logger = WandbLogger(project='gsn_super_resolution')

trainer = pl.Trainer(logger=wandb_logger, log_every_n_steps=2)

# Test
trainer.test(model=model, datamodule=dm)

# Zamknięcie wandb
wandb.finish()

# Archive model

In [None]:
model_archive_path = '/gsn/models_archive/'
if not os.path.exists(model_archive_path):
   os.makedirs(model_archive_path)

file_path = model_archive_path + "model_%s.onnx" % now.strftime('%Y-%m-%d_%H:%M:%S')

input_sample = torch.randn((1, 1, IMAGE_SIZE, IMAGE_SIZE))
model.to_onnx(file_path, input_sample, export_params=True)