In [1]:
import os
import sys
from glob import glob
from io import BytesIO
from pathlib import Path

import matplotlib.pyplot as plt
import polars as pl
import requests
import torch
from diffusers import LDMSuperResolutionPipeline
from omegaconf import OmegaConf
from PIL import Image
from tqdm.auto import tqdm

from src.config import cfg
from src.dir import create_dir
from src.seed import seed_everything

cfg.exp_number = Path().resolve().name
print(OmegaConf.to_yaml(cfg, resolve=True))

seed_everything(cfg.seed)
pl.Config.set_fmt_str_lengths(1000)


  from .autonotebook import tqdm as notebook_tqdm


exp_number: '014'
run_time: base
data:
  input_root: ../../data/input
  train_path: ../../data/input/train_features.csv
  test_path: ../../data/input/test_features.csv
  sample_submission_path: ../../data/input/sample_submission.csv
  img_root: ../../data/input/images
  img_upscaled_root: ../../data/input/images_upscaled
  json_root: ../../data/input/traffic_lights
  depth_root: ../../data/input/depth
  output_root: ../../data/output
  results_root: ../../results
  results_path: ../../results/014/base
seed: 319
n_splits: 4
target_cols:
- x_0
- y_0
- z_0
- x_1
- y_1
- z_1
- x_2
- y_2
- z_2
- x_3
- y_3
- z_3
- x_4
- y_4
- z_4
- x_5
- y_5
- z_5
cnn:
  model_name: resnet18
  size: 256
  pretrained: true
  in_chans: 20
  target_size: 18
  lr: 0.001
  num_epochs: 10
  batch_size: 32



polars.config.Config

In [2]:
# データの読み込み
train = pl.read_csv(cfg.data.train_path, try_parse_dates=True)
test = pl.read_csv(cfg.data.test_path, try_parse_dates=True)
sample_submission = pl.read_csv(cfg.data.sample_submission_path, try_parse_dates=True)

# データの結合(label encoding用)
train_test = pl.concat([train, test], how="diagonal")

# scene列を作成 → これでGroupKFoldする
train = train.with_columns(pl.col("ID").str.split("_").list[0].alias("scene"))
test = test.with_columns(pl.col("ID").str.split("_").list[0].alias("scene"))


In [3]:
device = "cuda" if torch.cuda.is_available() else "cpu"
model_id = "CompVis/ldm-super-resolution-4x-openimages"


In [4]:
# load model and scheduler
pipeline = LDMSuperResolutionPipeline.from_pretrained(model_id)
pipeline = pipeline.to(device)


Loading pipeline components...:   0%|          | 0/3 [00:00<?, ?it/s]An error occurred while trying to fetch /home/marumarukun/.cache/huggingface/hub/models--CompVis--ldm-super-resolution-4x-openimages/snapshots/0b55ddf931a8e3a1b426b3a50ddcf325ff84f668/vqvae: Error no file named diffusion_pytorch_model.safetensors found in directory /home/marumarukun/.cache/huggingface/hub/models--CompVis--ldm-super-resolution-4x-openimages/snapshots/0b55ddf931a8e3a1b426b3a50ddcf325ff84f668/vqvae.
Defaulting to unsafe serialization. Pass `allow_pickle=False` to raise an error instead.
Loading pipeline components...:  33%|███▎      | 1/3 [00:00<00:00,  9.20it/s]An error occurred while trying to fetch /home/marumarukun/.cache/huggingface/hub/models--CompVis--ldm-super-resolution-4x-openimages/snapshots/0b55ddf931a8e3a1b426b3a50ddcf325ff84f668/unet: Error no file named diffusion_pytorch_model.safetensors found in directory /home/marumarukun/.cache/huggingface/hub/models--CompVis--ldm-super-resolution-4x-o

In [5]:
def get_relative_path(path):
    return os.path.join(cfg.data.img_root, path)


# 画像へのパス
image_path_root_list = [
    get_relative_path("{ID}/image_t.png"),
    get_relative_path("{ID}/image_t-0.5.png"),
    get_relative_path("{ID}/image_t-1.0.png"),
]

for row in tqdm(train_test.iter_rows(named=True)):
    for image_path_root in image_path_root_list:
        img_pil = Image.open(image_path_root.format(ID=row["ID"]))
        upscaled_img = pipeline(img_pil, num_inference_steps=8, eta=1).images[0]
        # pred = depth_anything_v2(img_pil)

        upscaled_img_path = image_path_root.format(ID=row["ID"]).replace("images", "images_upscaled")
        os.makedirs(os.path.dirname(upscaled_img_path), exist_ok=True)
        upscaled_img.save(upscaled_img_path)


100%|██████████| 8/8 [00:00<00:00,  9.90it/s]
100%|██████████| 8/8 [00:00<00:00, 52.88it/s]
100%|██████████| 8/8 [00:00<00:00, 55.96it/s]
100%|██████████| 8/8 [00:00<00:00, 54.93it/s]
100%|██████████| 8/8 [00:00<00:00, 57.10it/s]
100%|██████████| 8/8 [00:00<00:00, 48.38it/s]
100%|██████████| 8/8 [00:00<00:00, 47.13it/s]
100%|██████████| 8/8 [00:00<00:00, 52.76it/s]
100%|██████████| 8/8 [00:00<00:00, 45.77it/s]
100%|██████████| 8/8 [00:00<00:00, 54.49it/s]
100%|██████████| 8/8 [00:00<00:00, 52.18it/s]
100%|██████████| 8/8 [00:00<00:00, 51.83it/s]
100%|██████████| 8/8 [00:00<00:00, 52.50it/s]
100%|██████████| 8/8 [00:00<00:00, 55.64it/s]
100%|██████████| 8/8 [00:00<00:00, 51.95it/s]
100%|██████████| 8/8 [00:00<00:00, 57.47it/s]
100%|██████████| 8/8 [00:00<00:00, 54.19it/s]
100%|██████████| 8/8 [00:00<00:00, 57.39it/s]
100%|██████████| 8/8 [00:00<00:00, 53.36it/s]
100%|██████████| 8/8 [00:00<00:00, 57.25it/s]
100%|██████████| 8/8 [00:00<00:00, 50.42it/s]
100%|██████████| 8/8 [00:00<00:00,