In [1]:
import sys
sys.path.append('/opt/ml/code/stage1')

In [4]:
from stage1.utils import *
from stage1.modules.models import *
from stage1.modules.loss import *
from stage1.trainer import *
import stage1.data as data

import torchvision

import timm
from efficientnet_pytorch import EfficientNet
import albumentations as A

%matplotlib inline

# Utilities

In [5]:
'''
function to remove strange 'module' prefix and load the state dict
'''

from collections import OrderedDict
def load_and_fix_state_dict(state_path, prefix='module'):
    state_dict = torch.load(state_path)
    fixed_state_dict = OrderedDict()
    for name, param in state_dict.items():
        if name[:len(prefix)] == prefix:
            new_name = '.'.join(name.split('.')[1:])
            fixed_state_dict[new_name] = param
        else:
            raise TypeError(name)
            
    return fixed_state_dict

In [6]:
!nvidia-smi

Thu Apr  8 09:28:45 2021       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 418.67       Driver Version: 418.67       CUDA Version: 10.1     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|   0  Tesla P40           On   | 00000000:00:05.0 Off |                  Off |
| N/A   41C    P8    13W / 250W |      0MiB / 24451MiB |      0%      Default |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Processes:                                                       GPU Memory |
|  GPU       PID   Type   Process name                             Usage      |
|  No ru

# Basic Architectures

In [7]:
# l=np.load('/opt/ml/input/pretrained/F0_haiku.npz', allow_pickle=True)#, #fix_imports=True)#, , encoding='latin1')

In [8]:
hrnet_w18_small_v2_1a_all = timm.models.hrnet_w18_small_v2()
hrnet_w18_small_v2_1a_all.classifier = nn.Linear(hrnet_w18_small_v2_1a_all.classifier.in_features, 18)
setattr(hrnet_w18_small_v2_1a_all, 'name', '-'.join('hrnet_w18_small_v2_1a_all'.split('_')))
hrnet_w18_small_v2_1a_all_logger = CSVLogger()

In [9]:
effnet_b0_2da_age = timm.models.efficientnet_b0()
effnet_b0_2da_age.classifier = nn.Sequential(
    nn.Linear(effnet_b0_2da_age.classifier.in_features, 1024),
    nn.Dropout(0.5),
    nn.Linear(1024, 3)
)
setattr(effnet_b0_2da_age, 'name', 'effnet-b0-2da-age')
effnet_b0_2da_age_logger = CSVLogger()

In [10]:
effnet_b3_2da_age = timm.models.efficientnet_b3(True)
effnet_b3_2da_age.classifier = nn.Sequential(
    nn.Linear(effnet_b3_2da_age.classifier.in_features, 1024),
    nn.Dropout(0.5),
    nn.Linear(1024, 3)
)
setattr(effnet_b3_2da_age, 'name', 'effnet-b3-2da-age')
effnet_b0_2da_age_logger = CSVLogger()

In [11]:
effnet_b3_age = timm.models.efficientnet_b3(True)
effnet_b3_age.classifier = nn.Linear(
    effnet_b3_age.classifier.in_features, 3
)
setattr(effnet_b3_age, 'name', 'effnet-b3-1a-age-raw')

effnet_b3_age_logger = CSVLogger()

In [12]:
effnet_b3_gender = timm.models.efficientnet_b3(True)
effnet_b3_gender.classifier = nn.Linear(
    effnet_b3_gender.classifier.in_features, 2
)
setattr(effnet_b3_gender, 'name', 'effnet_b3_gender')

effnet_b3_gender_logger = CSVLogger()

In [13]:
effnet_b3_mask = timm.models.efficientnet_b3(True)
effnet_b3_mask.classifier = nn.Linear(
    effnet_b3_mask.classifier.in_features, 3
)
setattr(effnet_b3_mask, 'name', 'effnet_b3_mask')

effnet_b3_mask_logger = CSVLogger()

In [14]:
effnet_b0_1a = timm.models.efficientnet_b0(True)
effnet_b0_1a.classifier = nn.Linear(effnet_b0_1a.classifier.in_features, 18)
setattr(effnet_b0_1a, 'name', 'effnet-b0-1a')

effnet_b0_1a_logger = CSVLogger()

In [15]:
effnet_b0_2da = timm.models.efficientnet_b0(True)
effnet_b0_2da.classifier = nn.Sequential(
    nn.Linear(effnet_b0_2da.classifier.in_features, 1024),
    nn.Dropout(0.5),
    nn.Linear(1024, 18)
)
setattr(effnet_b0_2da, 'name', 'effnet-b0-2da')

effnet_b0_2da_logger = CSVLogger()

# Configurations

In [16]:
# new config
config = ConfigTree()

In [17]:
# system
config.system.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
config.system.num_workers = 4

In [18]:
# path
config.path.input = os.path.join('/', 'opt', 'ml', 'input')
config.path.base = os.path.join(config.path.input, 'data')
config.path.train = os.path.join(config.path.base, 'train', 'all_')
config.path.test = os.path.join(config.path.base, 'eval', 'images')
config.path.valid = None

config.path.output = os.path.join('/', 'opt', 'ml', 'output')
config.path.models = os.path.join(config.path.output, 'models')
config.path.archive = os.path.join(config.path.models, 'archive')
config.path.configs = os.path.join(config.path.output, 'configs')
config.path.logs = os.path.join(config.path.output, 'logs')

config.path.pretrained = os.path.join(config.path.input, 'pretrained')

In [19]:
# data
config.data.valid_ratio = 0.2
config.data.valid_balanced = False
config.data.upscale = False
config.data.train_cutmix = False
config.data.cutmix_kernel = 256
config.data.soften_age = False
config.data.custom_augment = [
    A.Rotate(limit=15, always_apply=True),
    A.HorizontalFlip(always_apply=True),
    A.GaussNoise(var_limit=20, always_apply=True),
    A.ShiftScaleRotate(scale_limit=0.2, rotate_limit=15, always_apply=True)
]
config.data.preprocess = False
config.data.crop_size = None
config.data.resize = None
config.data.sampler = None   # not implemented yet
config.data.batch_size = 16

In [20]:
# hyperparameters for training
config.train.lr = ConfigBranch(
    base = 1e-3,
    backbone = 0,
    classifier = 0,
    scheduler = False,
    scheduler_kwarg = ConfigBranch(),
    few = 3,
    divider = 10,
    low_limit = 1e-8,
)
config.train.weight_decay = 1e-4
config.train.betas = (0.9, 0.999)
config.train.momentum = 0.9
config.train.nesterov = False
config.train.loss = ConfigBranch(
    criterion = 'CrossEntropyLoss',
#     criterion = FocalLoss(),
)
config.train.optimizer = ConfigBranch(
    name = 'AdamP',
    separate = False,
)

In [21]:
# monitor settings
config.train.num_epochs = 10
config.train.valid_iters = 0

config.train.valid_min = 1
config.train.test_min = 1
config.train.save_min = 1

config.train.valid_period = 1
config.train.test_period = 0
config.train.save_period = 1
config.train.shuffle_period = 0
config.train.plot_period = 5

config.train.valid_min_acc = 0
config.train.test_min_acc = 1
config.train.save_min_acc = 0
config.train.plot_min_acc = 0

config.train.logger = CSVLogger()

In [22]:
# model
config.model = ConfigBranch(
    model = torchvision.models.resnet18(False),
    teacher = None,
    classifying = 'all',
    state_path = ''
)
setattr(config.model.model, 'name', 'dummy_model')

# Trainers

In [24]:
trainer = Trainer(config)

# Models + Training

In [30]:
'''[STRATIFIED FOLD #0 였던 것] effnet-b3-multi (smaller)(CE)(AdamP)
>>> CHECKPOINTS
[lowest loss]()<- ""
[best f1 score]()<- ""
(epoch 1)
 /opt/ml/output/models/2021-04-08-09_20_14.317066_effnet-b3-multi+smaller+ce+adamp_0_upscaled.pth

'''

fold_idx = 0


effnet_b3_multi = timm.models.efficientnet_b3(True)
effnet_b3_multi.classifier = MultiheadClassifier(
    in_features=effnet_b3_multi.classifier.in_features, 
    out_features=18
)
setattr(effnet_b3_multi, "name", f"effnet-b3-multi+smaller+ce+adamp_{fold_idx}")
effnet_b3_multi_logger = CSVLogger()

In [None]:
'''[STRATIFIED FOLD #0] effnet-b0 (smaller)(CE)(AdamP)
>>> CHECKPOINTS
[lowest loss]()<- ""
[best f1 score]()<- ""
(epoch 1)
 /opt/ml/output/models/2021-04-08-09_20_14.317066_effnet-b3-multi+smaller+ce+adamp_0_upscaled.pth

'''

fold_idx = 0

effnet_b3_mul = timm.models.efficientnet_b3(True)
effnet_b3_multi.classifier = MultiheadClassifier(
    in_features=effnet_b3_multi.classifier.in_features, 
    out_features=18
)
setattr(effnet_b3_multi, "name", f"effnet-b0+smaller+ce+adamp_{fold_idx}")
effnet_b3_multi_logger = CSVLogger()

In [31]:
branch = ConfigBranch(
    train_path = f'/opt/ml/input/data/train/smaller_',
    valid_path = None,
    model = effnet_b3_multi,
    logger = effnet_b3_multi_logger,
    classifying = 'all', 
    upscale = True,
    cutmix = False,
    criterion = 'CrossEntropyLoss',
    optimizer = 'AdamP',
    lr = 1e-3,
    state_path = '/opt/ml/output/models/2021-04-08-09_20_14.317066_effnet-b3-multi+smaller+ce+adamp_0_upscaled.pth',
)

config.path.train = branch.train_path
config.path.valid = branch.valid_path
config.model.model = branch.model
config.train.logger = branch.logger
config.model.classifying = branch.classifying
config.data.upscale = branch.upscale
config.data.train_cutmix = branch.cutmix
config.train.optimizer.name = branch.optimizer
config.train.lr.base = branch.lr
config.train.weight_decay = 1e-4
config.train.loss.criterion = branch.criterion
config.model.state_path = branch.state_path

try:
    trainer().load_state_dict_to_model()
except KeyError:
    trainer().model.load_state_dict(load_and_fix_state_dict(config.model.state_path))

# trainer.train_and_save()
trainer.valid()
trainer.infer_and_save_csv()

Loaded state dict to model.
[Valid 001]  Loss: 0.06642,  Acc: 97.623,  F1 Score: 0.97033
effnet-b3-multi+smaller+ce+adamp_0: End of evaluation (01:59)
Saved result: /opt/ml/output/effnet-b3-multi+smaller+ce+adamp_0_2021-04-08-09_54_52.937448.csv


'/opt/ml/output/effnet-b3-multi+smaller+ce+adamp_0_2021-04-08-09_54_52.937448.csv'

In [25]:
'''effnet-b3-age (smaller)(ArcFace)(AdamP)
>>> CHECKPOINTS
[lowest loss]()<- ""
[best f1 score]()<- ""
(1 epoch)
/opt/ml/output/models/2021-04-08-09_41_48.519864_effnet-b4-age+smaller+ce+adamp_.pth
'''

effnet_b4_age = EfficientNet.from_pretrained('efficientnet-b3')
effnet_b4_age._fc = nn.Linear(effnet_b4_age._fc.in_features, 3)
setattr(effnet_b4_age, "name", "effnet-b4-age+smaller+ce+adamp")
effnet_b4_age_logger = CSVLogger()

In [None]:
branch = ConfigBranch(
    train_path = f'/opt/ml/input/data/train/smaller_',
    valid_path = None,
    model = effnet_b4_age,
    logger = effnet_b4_age_logger,
    classifying = 'age', 
    upscale = False,
    cutmix = False,
    criterion = 'ArcFaceLoss',
    optimizer = 'AdamP',
    lr = 1e-3,
    state_path = '',
)

config.path.train = branch.train_path
config.path.valid = branch.valid_path
config.model.model = branch.model
config.train.logger = branch.logger
config.model.classifying = branch.classifying
config.data.upscale = branch.upscale
config.data.train_cutmix = branch.cutmix
config.train.optimizer.name = branch.optimizer
config.train.lr.base = branch.lr
config.train.weight_decay = 1e-4
config.train.loss.criterion = branch.criterion
config.model.state_path = branch.state_path

try:
    trainer().load_state_dict_to_model()
except KeyError:
    trainer().model.load_state_dict(load_and_fix_state_dict(config.model.state_path))

trainer.train_and_save()
# trainer.valid()
# trainer.infer_and_save_csv()

[INFO]
model=effnet-b4-age+smaller+ce+adamp
device=cuda:0(Tesla P40)
data size=(12959 + 3239), batch size=16
optimizer.name=AdamP
epochs=10, lr=0.001, weight_decay=0.0001, betas=(0.9, 0.999)

Start of traning.
[Epoch 002]  Loss: 3.74217,  Acc: 93.132,  F1 Score: 0.78218  (11:23)
[Valid 002] (Batch #077)  Loss: 5.69441,  Acc: 88.6220

In [None]:
'''effnet-b3-age-soften (smaller)(ArcFace)(AdamP)
>>> CHECKPOINTS
[lowest loss]()<- ""
[best f1 score]()<- ""
'''

effnet_b3_age_s = EfficientNet.from_pretrained('efficientnet-b3')
effnet_b3_age_s._fc = nn.Linear(effnet_b3_age_s._fc.in_features, 3)
setattr(effnet_b3_age_s, "name", "effnet-b3-age-soft+smaller+ce+adamp")
effnet_b3_age_s_logger = CSVLogger()

branch = ConfigBranch(
    train_path = f'/opt/ml/input/data/train/smaller_',
    valid_path = None,
    model = effnet_b3_age_s,
    logger = effnet_b3_age_s_logger,
    classifying = 'age', 
    soften_age = True,
    upscale = False,
    cutmix = False,
    criterion = 'ArcFaceLoss',
    optimizer = 'AdamP',
    lr = 1e-3,
    state_path = '',
)

config.path.train = branch.train_path
config.path.valid = branch.valid_path
config.model.model = branch.model
config.train.logger = branch.logger
config.model.classifying = branch.classifying
config.data.upscale = branch.upscale
config.data.train_cutmix = branch.cutmix
config.data.soften_age = branch.soften_age
config.train.optimizer.name = branch.optimizer
config.train.lr.base = branch.lr
config.train.weight_decay = 1e-4
config.train.loss.criterion = branch.criterion
config.model.state_path = branch.state_path

try:
    trainer().load_state_dict_to_model()
except KeyError:
    trainer().model.load_state_dict(load_and_fix_state_dict(config.model.state_path))

trainer.train_and_save()
# trainer.valid()
# trainer.infer_and_save_csv()

In [25]:
'''
Student: EfficientNet-b0, cutmix_random (KDLoss)
'''

effnet_b0_student = timm.models.efficientnet_b0(True)
effnet_b0_student.classifier = MultiheadClassifier(
    in_features=effnet_b0_student.classifier.in_features, 
    out_features=18
)
setattr(effnet_b0_student, "name", f"effnet-b0-multi_student_+smaller+ce+adamp")
effnet_b0_student_logger = CSVLogger()

branch = ConfigBranch(
    train_path = f'/opt/ml/input/data/train/smaller_',
    valid_path = None,
    teacher = effnet_b3_multi,
    model = effnet_b0_student,
    logger = effnet_b0_student_logger,
    classifying = 'all', 
    upscale = False,
    cutmix = 'random',
    criterion = KDLoss(T=2.5, alpha=0.75, num_classes=18),
    optimizer = 'AdamP',
    lr = 1e-3,
    state_path = '',
)

config.path.train = branch.train_path
config.path.valid = branch.valid_path
config.model.model = branch.model
config.model.teacher = branch.teacher
config.train.logger = branch.logger
config.model.classifying = branch.classifying
config.data.upscale = branch.upscale
config.data.train_cutmix = branch.cutmix
config.train.optimizer.name = branch.optimizer
config.train.lr.base = branch.lr
config.train.weight_decay = 1e-4
config.train.loss.criterion = branch.criterion
config.model.state_path = branch.state_path


try:
    trainer().load_state_dict_to_model()
except KeyError:
    trainer().model.load_state_dict(load_and_fix_state_dict(config.model.state_path))

trainer.train_and_save()
# trainer.valid()
# trainer.infer_and_save_csv()

[INFO]
model=effnet-b0-multi_student_+smaller+ce+adamp_0
device=cuda:0(Tesla P40)
data size=(12166 + 4032), batch size=16
optimizer.name=AdamP
epochs=10, lr=0.001, weight_decay=0.0001, betas=(0.9, 0.999)

Start of traning.
[Epoch 001]  Loss: 45.81765,  Acc: 79.385,  F1 Score: 0.46161  (10:37)
[Valid 001]  

TypeError: forward() missing 1 required positional argument: 'targets'

## Single model TTA inference

In [None]:
branch = ConfigBranch(
    model = effnet_b3_2da_age,
    classifying = 'age',
    state_path = '/opt/ml/output/models/2021-04-06 02:22:12.233049_effnet_b3_2d_age_upscaled.pth'
)      

config.model.model = branch.model
config.model.classifying = branch.classifying
config.model.state_path = branch.state_path
trainer().load_state_dict_to_model()
predictions = trainer.infer_with_simple_tta()

predictions.head()

In [None]:
csv_name = f"effnet-b3-2da-age-upscaled-tta__{filename_from_datetime(datetime.today().time())}.csv"
csv_path = os.path.join(config.path.output, csv_name)

predictions.to_csv(csv_path)
print(f"Saved result: {csv_path}")

# Ensemble

## Single criterion revision

In [27]:
label_csv = '/opt/ml/output/effnet-b0-multi-raw+ce(as)+adamp_2021-04-07-21_21_25.613439.csv'

age_csv = '/opt/ml/output/effnet-b0-1a-age-soft+raw+ce+adamp_2021-04-08-01_08_05.839885.csv'
gender_csv = ''
mask_csv = ''

In [28]:
if label_csv:
    labels = pd.read_csv(label_csv, index_col='ImageID')
elif not (age_csv and gender_csv and mask_csv):
    raise ValueError

if age_csv:
    ages = pd.read_csv(age_csv, index_col='ImageID')
else:
    ages = pd.DataFrame(data={'ans': labels.ans.map(data.age_from_label)}, index=labels.index)
    
if gender_csv:
    genders = pd.read_csv(gender_csv, index_col='ImageID')
else:
    genders = pd.DataFrame(data={'ans': labels.ans.map(data.gender_from_label)}, index=labels.index)
    
if mask_csv: 
    masks = pd.read_csv(mask_csv, index_col='ImageID')
else:
    masks = pd.DataFrame(data={'ans': labels.ans.map(data.mask_from_label)}, index=labels.index)

info = pd.read_csv('/opt/ml/output/info.csv', index_col='ImageID')

assert (labels.index == info.index).all()
assert (ages.index == info.index).all()
assert (genders.index == info.index).all()
assert (masks.index == info.index).all()

In [29]:
labels.head()

Unnamed: 0_level_0,ans
ImageID,Unnamed: 1_level_1
cbc5c6e168e63498590db46022617123f1fe1268.jpg,13
0e72482bf56b3581c081f7da2a6180b8792c7089.jpg,1
b549040c49190cedc41327748aeb197c1670f14d.jpg,13
4f9cb2a045c6d5b9e50ad3459ea7b791eb6e18bc.jpg,13
248428d9a4a5b6229a7081c32851b90cb8d38d0c.jpg,12


In [30]:
ages.ans = ages.ans.map(data.age_from_label)
ages.ans.value_counts()

0    5428
2    4051
1    3121
Name: ans, dtype: int64

In [31]:
new_labels = pd.DataFrame(
    data={'ans': data.labels(ages.ans, genders.ans, masks.ans)},
    index=labels.index
)
new_labels.head()

Unnamed: 0_level_0,ans
ImageID,Unnamed: 1_level_1
cbc5c6e168e63498590db46022617123f1fe1268.jpg,14
0e72482bf56b3581c081f7da2a6180b8792c7089.jpg,2
b549040c49190cedc41327748aeb197c1670f14d.jpg,14
4f9cb2a045c6d5b9e50ad3459ea7b791eb6e18bc.jpg,13
248428d9a4a5b6229a7081c32851b90cb8d38d0c.jpg,12


In [32]:
(new_labels - labels).sum()

ans    2719
dtype: int64

In [34]:
prefix = "[age]effnet-b0-1a-age-soften+[gender][mask]effnet-b0-multi-all-raw"

csv_name = f"{prefix}__{filename_from_datetime(datetime.today().time())}.csv"
csv_path = os.path.join(config.path.output, csv_name)

new_labels.to_csv(csv_path)

print(f"Saved result: {csv_path}")

Saved result: /opt/ml/output/[age]effnet-b0-1a-age-soften+[gender][mask]effnet-b0-multi-all-raw__01_09_45.404480.csv


## Ensemble of n(x/w)

In [None]:
models = [
    (),
]

for model, state_path, _ in models:
    model.load_state_dict(torch.load(state_path, map_location=config.system.device))

print("<All keys have been matched.>")

In [None]:
ensemble_and_infer_test(trainer, models, weighted=False)

## Just to save each inference to csv

In [None]:
for model, _, classifying in models:
    print(f"{branch.model.name}", end="")
    config.model.model = model
    config.model.classifying = classifying
    trainer().infer_and_save_csv()

# Result check and visualization

In [28]:
pd.read_csv('/opt/ml/output/archive/effnet-b0-1a-smaller-raw_2021-04-07 19_24_27.092316.csv').head()

Unnamed: 0,ImageID,ans
0,cbc5c6e168e63498590db46022617123f1fe1268.jpg,14
1,0e72482bf56b3581c081f7da2a6180b8792c7089.jpg,2
2,b549040c49190cedc41327748aeb197c1670f14d.jpg,14
3,4f9cb2a045c6d5b9e50ad3459ea7b791eb6e18bc.jpg,14
4,248428d9a4a5b6229a7081c32851b90cb8d38d0c.jpg,12


In [None]:
test_size = len(ages)
num_samples = 100

fig, axes = plt.subplots(
    num_samples // 5, 5, 
    figsize=(10, 5 * num_samples // 8)
)

for i, idx in enumerate(np.random.randint(0, test_size, num_samples)):
    filename = ages.index[idx]
    prediction = ages.iloc[idx].ans
    label = labels.loc[filename].ans
    
    filepath = os.path.join(config.path.test, filename)
    image = cv2.imread(filepath)
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    image = image / 255.
    
        
    if prediction % 3 == 0:
        age = 'under 30'
    elif prediction % 3 == 1:
        age = '30 to 60'
    else:
        age = '60 or more'
        
    if label % 3 == 0:
        age_l = 'under 30'
    elif prediction % 3 == 1:
        age_l = '30 to 60'
    else:
        age_l = '60 or more'
    
    r, c = divmod(i, 5)
    axes[r, c].imshow(image)
    axes[r, c].set_title(f"{age_l} -> {age}", fontsize=12)
    axes[r, c].axis('off')
    
plt.tight_layout()
plt.show()

In [None]:
predictions = labels

test_size = len(predictions)
num_samples = 100

fig, axes = plt.subplots(
    num_samples // 5, 5, 
    figsize=(10, 5 * num_samples // 7)
)

for i, idx in enumerate(np.random.randint(0, test_size, num_samples)):
    filename = predictions.index[idx]
    prediction = predictions.iloc[idx].ans
    filepath = os.path.join(config.path.test, filename)
    image = cv2.imread(filepath)
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    image = image / 255.
    
    if prediction // 6 == 0:
        mask = 'correct'
    elif prediction // 6 == 1:
        mask = 'incorrect'
    elif prediction // 6 == 2:
        mask = 'not wear'
    else:
        raise "?"
    
    if (prediction // 3) % 2 == 0:
        gender = 'male'
    else:
        gender = 'female'
        
    if prediction % 3 == 0:
        age = 'under 30'
    elif prediction % 3 == 1:
        age = '30 to 60'
    else:
        age = '60 or more'
    
    r, c = divmod(i, 5)
    axes[r, c].imshow(image)
    axes[r, c].set_title(f"{mask}\n{gender}\n{age}", fontsize=12)
    axes[r, c].axis('off')
    
plt.tight_layout()
plt.show()

# Augmentation check

In [48]:
if config.data.upscale:
    train_set = trainer.train_loader.dataset
    num_samples = 10
    sample_indices = np.random.randint(0, len(train_set), num_samples)
    fig, axes = plt.subplots(num_samples // 5, 5, figsize=(10, 5))
    for i, idx in enumerate(sample_indices):
        image = train_set[idx][0].permute(1, 2, 0).numpy()
        subfig = axes[divmod(i, 5)]
        subfig.axis('off')
        subfig.imshow(image)

    plt.tight_layout()
    plt.show()