# Exponentially decaying sinusoid
Scratch work for an exponentially decaying cosine learning-rate schedule.

In [None]:
import numpy as np
import matplotlib.pyplot as plt
def decay_lr(minlr, maxlr, alpha, beta, E, t):
    """Exponentially damped raised-cosine learning rate.

    Evaluates
        minlr + 0.5 * (maxlr - minlr) * (1 + cos(2 * pi * beta * t)) * exp(-alpha * t / E)

    :param minlr: floor the schedule oscillates/decays towards
    :param maxlr: value of the schedule at t == 0
    :param alpha: decay rate; the envelope equals exp(-alpha) at t == E
    :param beta: cosine frequency, in cycles per unit of t
    :param E: time scale of the exponential decay
    :param t: time point(s); scalar or numpy array
    :returns: learning rate(s), broadcast to the shape of t
    """
    half_span = 0.5 * (maxlr - minlr)
    oscillation = half_span * (np.cos(beta * t * 2 * np.pi) + 1)
    envelope = np.exp(-alpha * t / E)
    return minlr + oscillation * envelope

fig1, ax1 = plt.subplots(figsize=(12, 5))
epochs = 100
# min_lr must sit below max_lr: decay_lr scales its cosine/exponential terms by
# (maxlr - minlr), so equal bounds would flatten every curve to a constant line.
min_lr = 0.0
max_lr = 0.001
alpha = .75  # envelope decay rate: the envelope reaches exp(-alpha) at t == epochs
beta = .2    # cosine frequency, cycles per epoch
t = np.arange(epochs, step=0.001)
# beta=0 pins the cosine at its peak -> pure exponential envelope;
# alpha=0 disables the decay -> pure cosine; both set -> the combined schedule.
exp_term = decay_lr(min_lr, max_lr, alpha, 0, epochs, t)
sine_term = decay_lr(min_lr, max_lr, 0, beta, epochs, t)
final = decay_lr(min_lr, max_lr, alpha, beta, epochs, t)
ax1.plot(t, exp_term, label='Exponential')
ax1.plot(t, sine_term, label='Sinusoid')
ax1.plot(t, final, label='Decaying cosine')
# ax1.hlines(min(final), 0, epochs, colors='k', label='minimum')
ax1.set(ylim=(0, 1.075*max_lr))
ax1.legend(loc='upper right')
fig1.tight_layout()
plt.show()

# Cosine Annealing with Warm Restarts

<!-- ![](https://production-media.paperswithcode.com/methods/Screen_Shot_2020-05-30_at_5.46.29_PM.png) -->
<img align="left" src="https://production-media.paperswithcode.com/methods/Screen_Shot_2020-05-30_at_5.46.29_PM.png" style="width:400px;"/>

#### From *Papers with Code*:
Cosine Annealing is a type of learning rate schedule that has the effect of starting with a large learning rate that is relatively rapidly decreased to a minimum value before being increased rapidly again. The resetting of the learning rate acts like a simulated restart of the learning process and the re-use of good weights as the starting point of the restart is referred to as a "warm restart" in contrast to a "cold restart" where a new set of small random numbers may be used as a starting point.


In [None]:
import matplotlib.pyplot as plt
import numpy as np
import math

def warm_restarts(eta_min, eta_max, T_0, T_mult, epoch):
    """Closed-form cosine annealing with warm restarts (SGDR) evaluated at ``epoch``.

    The first cycle lasts ``T_0`` epochs and each later cycle is ``T_mult`` times
    longer than the previous one. Within a cycle of length ``T_i``, with ``T_cur``
    epochs elapsed since the last restart:

        lr = eta_min + (eta_max - eta_min) * (1 + cos(pi * T_cur / T_i)) / 2

    :param eta_min: minimum learning rate (reached at the end of each cycle)
    :param eta_max: learning rate right after each (re)start
    :param T_0: length of the first cycle; must be positive
    :param T_mult: cycle-length growth factor; must be >= 1
    :param epoch: epoch at which to evaluate the schedule (int or float)
    :returns: learning rate at ``epoch``
    :raises ValueError: if ``T_0 <= 0`` or ``T_mult < 1``
    """
    if T_0 <= 0:
        raise ValueError(f'Expected positive T_0, but got {T_0}')
    if T_mult < 1:
        # A shrinking cycle length would never let the restart walk below terminate.
        raise ValueError(f'Expected T_mult >= 1, but got {T_mult}')

    T_i = T_0
    T_cur = epoch
    if epoch >= T_0:
        if T_mult == 1:
            T_cur = epoch % T_0
        else:
            # Walk the geometrically growing cycles instead of inverting the
            # geometric series with a float log, which can round down exactly at
            # a restart boundary (e.g. log(243, 3) < 5) and skip the restart.
            while T_cur >= T_i:
                T_cur -= T_i
                T_i *= T_mult
    lr = eta_min + (eta_max - eta_min) * (1 + np.cos(np.pi * T_cur / T_i)) / 2
    return lr

T_0 = 10
T_mult = 2
eta_min = 0.0
base_lr = 0.001  # eta_max

# Evaluate the closed-form schedule once per epoch and plot it.
epochs = range(100)
carst = [warm_restarts(eta_min, base_lr, T_0, T_mult, epoch=e) for e in epochs]
fig, ax = plt.subplots(figsize=(12, 5))
ax.plot(epochs, carst)
# x values are epoch indices (T_cur resets at each restart), so label them as epochs.
ax.set(xlabel='Epoch', ylabel='Learning Rate')
plt.show()

# Cosine Annealing (no warm restarts)

In [None]:
import matplotlib.pyplot as plt
import numpy as np

def cosine_annealing(T_max, eta_min, eta_max, epoch, group_lr=None):
    """Learning rate of (non-restarting) cosine annealing at ``epoch``.

    Mirrors the per-step rule of a cosine-annealing scheduler with period
    ``2 * T_max``:
      * epoch 0 returns the current learning rate unchanged;
      * epochs where ``(epoch - 1 - T_max) % (2 * T_max) == 0`` step up from the
        previous learning rate ``group_lr`` by ``0.5 * (eta_max - eta_min) * (1 - cos(pi / T_max))``;
      * every other epoch uses the closed form
        ``eta_min + 0.5 * (eta_max - eta_min) * (1 + cos(epoch * pi / T_max))``.

    :param T_max: half-period of the cosine, in epochs
    :param eta_min: minimum learning rate
    :param eta_max: maximum learning rate
    :param epoch: epoch to evaluate
    :param group_lr: previous learning rate; defaults to ``eta_max`` when None
    :returns: learning rate for ``epoch``
    """
    current = eta_max if group_lr is None else group_lr
    if epoch == 0:
        return current

    half_span = 0.5 * (eta_max - eta_min)
    if (epoch - 1 - T_max) % (2 * T_max) == 0:
        return current + half_span * (1 - np.cos(np.pi / T_max))
    return eta_min + half_span * (1 + np.cos(epoch * np.pi / T_max))

T_max = 100
epochs = range(1000)
eta_min = 0.0
eta_max = 0.01

# Seed with the epoch-0 value, then feed each previous lr back in: the
# step-up branch of cosine_annealing builds on group_lr.
cosann = [cosine_annealing(T_max, eta_min, eta_max, 0)]
for e in epochs[1:]:
    cosann.append(cosine_annealing(T_max, eta_min, eta_max, e, cosann[-1]))

fig, ax = plt.subplots(figsize=(12, 5))
ax.plot(epochs, cosann)
# x values are epoch indices, not T_cur.
ax.set(xlabel='Epoch', ylabel='Learning Rate')
plt.show()


In [None]:
T_max = 10
# List the epochs where (epoch - 1 - T_max) is a multiple of 2 * T_max, i.e.
# where cosine_annealing takes its step-up branch.
for epoch in range(100):
    if (epoch - 1 - T_max) % (2 * T_max):
        continue
    print(f'True for {epoch}')

In [None]:
from ignite.metrics import *
import os
from src.utils import import_spm, plot_tensors

metric = SSIM(data_range=1.0)
folder = "results/denoisedredux0.9-raw-02061403/"
# os.listdir returns entries in arbitrary order. Sorting lines up the i-th
# "denoised" file with the i-th "target" file when the two share a filename prefix.
# NOTE(review): assumes pred/target names differ only by that tag — confirm.
all_results = sorted(f for f in os.listdir(folder) if "montage" not in f)
ins = [import_spm(os.path.join(folder, i))[1] for i in all_results if "noisy" in i]
preds = [import_spm(os.path.join(folder, p))[1] for p in all_results if "denoised" in p]
targs = [import_spm(os.path.join(folder, t))[1] for t in all_results if "target" in t]
# zip() would silently drop unmatched files; fail loudly instead.
if len(preds) != len(targs):
    raise ValueError(f'found {len(preds)} denoised files but {len(targs)} target files in {folder}')
for p, t in zip(preds, targs):
    # Add a leading batch dimension before updating the metric.
    metric.update((p.unsqueeze(dim=0), t.unsqueeze(dim=0)))
    val = metric.compute()  # accumulated over all pairs seen so far
    print(f'updated result = {val}')
    # plot_tensors(p, t, titles=['Pred', 'Targ'], small_fig=True)

# metric.reset()
print(f'final result = {metric.compute()}')

In [None]:
# import torch
# import random
# import torchvision.transforms.functional as F
# from torch import Tensor

# from typing import Union

# class FunctionalRandomResizedCrop(torch.nn.Module):
#     """ Random square cropping class that sets a fixed crop location for cropping pairs of images.
#     """
    
#     def __init__(self, size: int, top: int, left: int):
#         super().__init__()
#         # if type(size) is tuple and len(size) > 2:
#         #     raise ValueError('Please only provide two dimensions (h, w) for size. ')
#         # elif type(size) is int:
#         #     self.size = (size, size)
#         # else:
#         self.size = size
#         self.top = top
#         self.left = left

#     def forward(self, img):
#         _, h, w = F.get_dimensions(img)
#         if self.size > h or self.size > w:

# class CosineExpDecayLR(LRScheduler):
#     def adjust_lr(optimizer, epoch, nb_epochs, learning_params=(0.0, 0.001, 6.5, 10.0), decay=True):
#     """ Adjusts optimizer learning rate to a sine wave.
#     :param optimizer: optimizer object
#     :param epoch: current epoch
#     :param learning_params: [min, max, alpha, beta] learning rate parameters
#     """
#     lr_min = learning_params[0]
#     A = learning_params[1] - learning_params[0]
#     alpha = learning_params[2]
#     beta = learning_params[3]
#     sine_term = 0.5 * A * (np.cos(beta * epoch * 2 * np.pi) + 1)
#     exp_term = np.exp(-alpha * epoch / nb_epochs)

#     if decay:
#         lr_new = lr_min + sine_term * exp_term
#     else:
#         lr_new = lr_min + sine_term
    
#     for group in optimizer.param_groups:
#         group['lr'] = lr_new
#     return optimizer

In [None]:
import pandas as pd
import matplotlib.pyplot as plt
import os

# Training statistics JSON from a previous run; the commented path points at an alternative run.
# file = "ckpts/tinyimagenet-sinelr-ssim-raw/tinyimagenet-sinelr-ssim-rawl2/n2n-stats.json"
file = "ckpts/scheduler-tests/tinyimagenet-cawrlr-ssimredux0.99-rawl2/n2n-stats.json"
# NOTE(review): assumes the JSON yields per-epoch 'train_loss' / 'valid_loss'
# columns (they are read when plotting) — confirm against whatever writes it.
train_stats = pd.read_json(file)

def _first_epoch_is_outlier(values):
    """Return True when the first value exceeds 100x the magnitude of the second."""
    head = list(values[:2])
    return len(head) == 2 and abs(head[0]) > 100 * abs(head[1])


def trainvalid_metric_plots(train_metric, valid_metric, metric_name):
    """Plot per-epoch train and validation curves for one metric.

    For loss metrics (``'loss'`` in ``metric_name``) whose epoch-1 value is more
    than 100x the epoch-2 value, the y-axis upper limit is taken from epochs 2+
    so the rest of the curve stays readable, and a small note is drawn.

    :param train_metric: per-epoch training values (e.g. a pandas Series), or
        None to plot the validation curve only
    :param valid_metric: per-epoch validation values
    :param metric_name: label used for the legend, y-axis and title
    """
    fig, ax = plt.subplots(dpi=200)
    curves = [m for m in (train_metric, valid_metric) if m is not None]
    if 'loss' in metric_name and any(_first_epoch_is_outlier(m) for m in curves):
        # Take the limits from every plotted curve so none of them gets clipped;
        # the upper bound skips epoch 1, which holds the outlier.
        lower = 0.98 * min(m.min() for m in curves)
        upper = 1.02 * max(m[1:].max() for m in curves if len(m) > 1)
        ax.set_ylim(lower, upper)
        ax.text(0, 0.985 * upper,
                '*Note: epoch 1 value(s) out of bounds',
                fontsize='xx-small')
    if train_metric is not None:
        ax.plot(range(1, len(train_metric) + 1), train_metric, label=f'Train {metric_name}')
    ax.plot(range(1, len(valid_metric) + 1), valid_metric, label=f'Valid {metric_name}')
    ax.set(xlabel='Epoch',
           ylabel=f'{metric_name}',
           title=f"{'Train and Valid' if train_metric is not None else 'Valid'} {metric_name}")
    ax.legend(loc='upper right')
    fig.tight_layout()
    plt.show()

# Plot the train/valid loss curves from the loaded stats.
trainvalid_metric_plots(train_stats['train_loss'], train_stats['valid_loss'], 'L2 - SSIMs + 1 loss')
# NOTE(review): the call below passes four arguments, but trainvalid_metric_plots
# takes three (train_metric, valid_metric, metric_name); update it before re-enabling.
# trainvalid_metric_plots('.', None, train_stats['train_loss'], 'L2 - SSIMs + 1 loss')


In [1]:
import torch
import torchvision.transforms as trf
from torchvision.transforms import functional as tvF
from time import sleep
from PIL import Image
from pathos.helpers import cpu_count
from pathos.pools import ProcessPool as Pool
from tqdm import tqdm
import os

In [3]:
def simple_image_import(filepath):
    """Load one image file as a tensor.

    :param filepath: path to an image readable by PIL
    :returns: ``(basename, tensor)`` where the tensor is produced by
        ``torchvision.transforms.functional.pil_to_tensor``
    """
    name = os.path.basename(filepath)
    with Image.open(filepath) as pil_image:
        # Pull the pixel data into memory before the context manager closes the file.
        pil_image.load()
    return name, tvF.pil_to_tensor(pil_image)


def crop_and_remove_padded(in_dict_item, out_dict, crop_size, pad_thresh=0.05):
    """Store an image in ``out_dict`` unless its corners look like padding.

    The image is treated as padded when its top-left or bottom-right 3x3 corner
    lies entirely within ``pad_thresh`` of the image's darkest or brightest value.
    Padded images are skipped. Unpadded images whose height or width differs from
    ``crop_size`` go through ``RandomResizedCrop(crop_size)``; the rest are stored
    unchanged.

    :param in_dict_item: ``(filename, tensor)`` pair; the tensor is indexed as (C, H, W)
    :param out_dict: dict that receives ``out_dict[filename]`` (mutated in place)
    :param crop_size: output side length for the random resized crop
    :param pad_thresh: fraction of the value range (max - min) treated as near black / near white
    :returns: None
    """
    fname, im_tensor = in_dict_item
    _, h, w = im_tensor.shape
    lo = im_tensor.min().item()
    hi = im_tensor.max().item()
    margin = pad_thresh * (hi - lo)
    # Thresholds are relative to the value range so that a minimum of 0 (common
    # for uint8 images) still gives a usable black threshold; ``min * (1 + pad_thresh)``
    # collapses to 0 there and black padding is never detected.
    black = lo + margin
    white = hi - margin

    # check for padding using top left and bottom right corners
    corners = (im_tensor[:, :3, :3], im_tensor[:, -3:, -3:])
    padded = any(
        bool(torch.all(corner <= black)) or bool(torch.all(corner >= white))
        for corner in corners
    )
    if padded:
        return

    if h != crop_size or w != crop_size:
        # RandomResizedCrop resizes its crop to crop_size, so this also handles
        # images smaller than crop_size.
        out_dict[fname] = trf.RandomResizedCrop(crop_size)(im_tensor)
    else:
        out_dict[fname] = im_tensor


In [4]:
root_dir = "results/denoisedredux0.9-raw-02061403/"
save = "."

supported = ['.png', '.jpg', '.jpeg']
# Sorted for a deterministic load order (os.listdir order is arbitrary).
image_files = sorted(f for f in os.listdir(root_dir) if os.path.splitext(f)[1].lower() in supported)
# normpath drops the trailing slash; otherwise basename() returns '' and the
# progress bar reads "Loading  images".
dir_name = os.path.basename(os.path.normpath(root_dir))
source_iter = tqdm(image_files, desc=f'Loading {dir_name} images', unit='img')
images_in = dict(simple_image_import(os.path.join(root_dir, s)) for s in source_iter)

Loading  images: 100%|██████████| 28/28 [00:01<00:00, 26.74img/s]


In [6]:
# Crop every image to the smallest spatial side found across the set.
sizes = [min(t.shape[1:]) for t in images_in.values()]
crop_size = min(sizes)
images_out = {}
for fname, im_tensor in images_in.items():
    crop_and_remove_padded((fname, im_tensor), images_out, crop_size)


In [None]:
# Scratch check of a corner-padding test on one montage image: take its
# top-left 3x3 corner across all channels (same slice crop_and_remove_padded uses).
test = images_in['HS-20MG-raw-montage.png']
test_topleft = test[:, :3, :3]
# Index tensors of corner elements above a fixed cutoff of 250
# (NOTE(review): presumably assumes 8-bit pixel values — confirm the dtype).
boolean_test = torch.where(test_topleft > 250)
# Python bool: True only if every corner element exceeds 250.
booltest2 = torch.all(test_topleft > 250).item()

In [15]:
from lightning.pytorch.trainer.states import *
from datetime import datetime, timedelta

# Scratch: map every TrainerFn member to a timestamp, then look one up with the
# plain string 'fit'. The recorded output shows members such as
# TrainerFn.FITTING: 'fit', and the lookup did not raise
# (presumably TrainerFn compares equal to its string value — confirm).
test = {state: datetime.now() for state in TrainerFn}
print([test])
test['fit']

# A zero-length timedelta prints as 0:00:00.
test2 = timedelta(0)
print(test2)

[{<TrainerFn.FITTING: 'fit'>: datetime.datetime(2024, 2, 22, 15, 8, 28, 506284), <TrainerFn.VALIDATING: 'validate'>: datetime.datetime(2024, 2, 22, 15, 8, 28, 506293), <TrainerFn.TESTING: 'test'>: datetime.datetime(2024, 2, 22, 15, 8, 28, 506295), <TrainerFn.PREDICTING: 'predict'>: datetime.datetime(2024, 2, 22, 15, 8, 28, 506297)}]
0:00:00
