In [None]:
import os
import matplotlib.pyplot as plt
import numpy as np

import starry.utils.env
from starry.vision.data.peris import PerisData


# Root directory of the dataset files, taken from the DATA_DIR environment variable.
DATA_DIR = os.getenv('DATA_DIR')
if DATA_DIR is None:
	# Fail early with a clear message; otherwise os.path.join(None, ...) below
	# raises an opaque TypeError.
	raise RuntimeError('DATA_DIR environment variable is not set')

# Build the PerisData dataset from the image archive and the label table,
# requesting only the 'score' label field, with shuffle=True.
data = PerisData(
	os.path.join(DATA_DIR, 'nse-100.zip'),
	os.path.join(DATA_DIR, 'labels-20220205.csv'),
	label_fields=['score'],
	shuffle=True,
)

# Pull one sample, print its labels and display the first image of the sample.
it = iter(data)
source, labels = next(it)
print(labels)
# Scaled by 255 and cast to uint8 for display
# (assumes pixel values are in [0, 1] — TODO confirm).
plt.imshow((source[0] * 255).astype(np.uint8))
plt.show()

In [None]:
# PerisSimple

from torch.utils.data import DataLoader

from starry.vision.models.peris import PerisSimpleLoss


# Wrap the dataset in a DataLoader producing single-sample batches,
# collated by the dataset's own collateBatch.
loader = DataLoader(
	data,
	batch_size=1,
	collate_fn=data.collateBatch,
)
it = iter(loader)

# Loss-wrapped model built on the efficientnet_b0 backbone.
# NOTE: kept after iter(loader) on purpose so the statement order is unchanged.
model = PerisSimpleLoss(backbone='efficientnet_b0')
batch = next(it)

# The model is called with the whole batch; its result is unpacked as (loss, metrics).
loss, metrics = model(batch)
(loss, metrics)

In [None]:
import os
import matplotlib.pyplot as plt
import numpy as np

from starry.utils.config import Configuration
from starry.utils.dataset_factory import loadDataset


DATA_DIR = os.getenv('DATA_DIR')

# Configuration created from the YAML file with volatile=True.
config = Configuration.create('configs/peris-score-simple.yaml', volatile=True)
# The result of loadDataset is unpacked as a one-element sequence.
[data] = loadDataset(config, data_dir=DATA_DIR, splits='*0/1')

# Fetch the first sample and print its labels.
it = iter(data)
source, labels = next(it)
print(labels)

# First image of the sample, axes reordered (0, 1, 2) -> (1, 2, 0) before display
# (presumably CHW -> HWC for matplotlib — confirm).
source = source[0].permute(1, 2, 0).numpy()
pixels = (source * 255).astype(np.uint8)
plt.imshow(pixels)
plt.show()


In [None]:
# prediction
import os
import matplotlib.pyplot as plt
import numpy as np
import torch

from starry.utils.config import Configuration
from starry.utils.dataset_factory import loadDataset
from starry.utils.model_factory import loadModel


DATA_DIR = os.getenv('DATA_DIR')

# Configuration of an existing training run directory.
config = Configuration('training/peris/20220217-peris-score-simple')
[data] = loadDataset(config, data_dir=DATA_DIR, splits='*0/1')

# Build the model from the 'model' section of the configuration.
model = loadModel(config['model'])
if config['best']:
	# Restore the weights stored under the checkpoint's 'model' key, on CPU.
	best_path = config.localPath(config['best'])
	checkpoint = torch.load(best_path, map_location='cpu')
	model.load_state_dict(checkpoint['model'])
	print(f'checkpoint loaded: {config["best"]}')
model.eval()

# Iterator consumed by the following "prediction iteration" cell.
it = iter(data)


In [None]:
# prediction iteration
# Run the model on the next sample with gradient tracking disabled.
with torch.no_grad():
	sample = next(it)
	source, labels = sample
	pred = model(source)
	print(labels, pred)

# Show the first image of the sample; axes reordered (0, 1, 2) -> (1, 2, 0)
# (presumably CHW -> HWC for matplotlib — confirm).
source = source[0].permute(1, 2, 0).numpy()
pixels = (source * 255).astype(np.uint8)
plt.imshow(pixels)
plt.show()


In [None]:
# BalanceLabeledImage
import os
#import matplotlib.pyplot as plt
import numpy as np

from starry.utils.config import Configuration
from starry.utils.dataset_factory import loadDataset


DATA_DIR = os.getenv('DATA_DIR')

# Configuration created from the "balance" YAML file with volatile=True.
config = Configuration.create(
	'configs/peris-score-simple-balance.yaml',
	volatile=True,
)
# The result of loadDataset is unpacked as a one-element sequence.
[data] = loadDataset(config, data_dir=DATA_DIR, splits='*0/1')

# Fetch the first sample and print its labels; `it` is reused by the next cell.
it = iter(data)
sample = next(it)
source, labels = sample
print(labels)

In [None]:
# BalanceLabeledImage iteration
import matplotlib.pyplot as plt

# Advance the iterator created in the loading cell and print the labels.
sample = next(it)
source, labels = sample
print(labels)

# Show the first image of the sample; axes reordered (0, 1, 2) -> (1, 2, 0)
# (presumably CHW -> HWC for matplotlib — confirm).
source = source[0].permute(1, 2, 0).numpy()
pixels = (source * 255).astype(np.uint8)
plt.imshow(pixels)
plt.show()

In [None]:
# dataset loading
import os
import numpy as np

from starry.utils.config import Configuration
from starry.utils.dataset_factory import loadDataset


DATA_DIR = os.getenv('DATA_DIR')

# Configuration created from the "score_binary" YAML file with volatile=True.
config = Configuration.create(
	'configs/peris-score-balance-score_binary.yaml',
	volatile=True,
)
# The result of loadDataset is unpacked as a one-element sequence.
[data] = loadDataset(config, data_dir=DATA_DIR, splits='*0/1')

# Fetch the first sample and print its labels; `it` is reused by the next cell.
it = iter(data)
sample = next(it)
source, labels = sample
print(labels)

In [None]:
# dataset iteration
import matplotlib.pyplot as plt

# Advance the iterator created in the "dataset loading" cell and print the labels.
sample = next(it)
source, labels = sample
print(labels)

# Show the first image of the sample; axes reordered (0, 1, 2) -> (1, 2, 0)
# (presumably CHW -> HWC for matplotlib — confirm).
source = source[0].permute(1, 2, 0).numpy()
pixels = (source * 255).astype(np.uint8)
plt.imshow(pixels)
plt.show()

In [None]:
# model test with loss
import os
import matplotlib.pyplot as plt
import numpy as np
import torch

from starry.utils.config import Configuration
from starry.utils.dataset_factory import loadDataset
from starry.utils.model_factory import loadModel


DATA_DIR = os.getenv('DATA_DIR')

config = Configuration.createOrLoad('configs/peris-score-balance-score_binary.yaml')
[data] = loadDataset(config, data_dir=DATA_DIR, splits='*0/1')

# Load the model variant selected by postfix='Loss'.
model = loadModel(config['model'], postfix='Loss')

it = iter(data)

# The loss variant is called with the (source, labels) pair as a single argument.
source, labels = next(it)
inputs = (source, labels)
pred = model(inputs)
print(labels, pred)


In [None]:
# model test with the best checkpoint (prediction, eval mode)
import os
import matplotlib.pyplot as plt
import numpy as np
import torch

from starry.utils.config import Configuration
from starry.utils.dataset_factory import loadDataset
from starry.utils.model_factory import loadModel


DATA_DIR = os.getenv('DATA_DIR')

# Training run directory, built with os.path.join so the path resolves on every
# platform; a hard-coded backslash-separated literal only works on Windows.
config = Configuration.createOrLoad(
	os.path.join('training', 'peris', '20220305-peris-score-balance-score_binary-b2')
)
[data] = loadDataset(config, data_dir=DATA_DIR, splits='*0/1')

# Build the model and restore the weights stored under the checkpoint's 'model' key.
model = loadModel(config['model'])
# NOTE(review): torch.load unpickles the file — only load checkpoints from trusted sources.
checkpoint = torch.load(config.localPath(config['best']), map_location='cpu')
model.load_state_dict(checkpoint['model'])
model.eval()


In [None]:

# A fresh iterator is created on every run of this cell, so each run starts
# a new pass over `data`.
it = iter(data)

# Run the model on the next sample with gradient tracking disabled.
with torch.no_grad():
	sample = next(it)
	source, labels = sample
	pred = model(source)
	print(labels, pred)

# Show the first image of the sample; axes reordered (0, 1, 2) -> (1, 2, 0)
# (presumably CHW -> HWC for matplotlib — confirm).
img = source[0].permute(1, 2, 0).numpy()
pixels = (img * 255).astype(np.uint8)
plt.imshow(pixels)
plt.show()


In [None]:
# PerisCaption
import os
import starry.utils.env
from starry.vision.data import PerisCaption


# Root directory of the dataset files, taken from the DATA_DIR environment variable.
DATA_DIR = os.getenv('DATA_DIR')
if DATA_DIR is None:
	# Fail early with a clear message; otherwise os.path.join(None, ...) below
	# raises an opaque TypeError.
	raise RuntimeError('DATA_DIR environment variable is not set')

# Build the PerisCaption dataset from the image archive and the label table.
data = PerisCaption(
	os.path.join(DATA_DIR, 'nse-100.zip'),
	os.path.join(DATA_DIR, 'labels-20220921.csv'),
)

# Display the first sample as the cell output; `it` is reused by the next cell.
it = iter(data)
next(it)

In [None]:
# Advance the iterator and display only element 1 of the next sample.
sample = next(it)
sample[1]