In [None]:
base_dir = "/home/vsl333/datasets/news-bert-data/bertopic/allcsvtopics"
image_dir = "/projects/belongielab/data/frame-align"

import os
import pandas as pd
import requests
from PIL import Image
import numpy as np
import torch
import logging
from tqdm import tqdm
import argparse


# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

def download_images(directory_name, id_list, image_url_list):
    for i in tqdm(range(len(id_list))):
        id = id_list[i]
        image_url = image_url_list[i]

        # download image using url and save to image_dir/directory_path. Create directory if it doesn't exist
        os.makedirs(os.path.join(image_dir, directory_name), exist_ok=True)
        # image_path = os.path.join(image_dir, directory_name, f"{id}.jpg")
        # download image url to image_path. Add timeout to prevent hanging
        try:
            response = requests.get(image_url, stream=True, timeout=20)  # Add a timeout (in seconds)
            response.raise_for_status()  # Raise an HTTPError if the status is not 200
            raw_image = Image.open(response.raw).convert("RGB")
            
            # Check the shape of the image tensor
            image_tensor = torch.tensor(np.array(raw_image))
            if image_tensor.shape[-1] != 3:
                raise ValueError(f"Unexpected image shape: {image_tensor.shape}")
            if image_tensor.shape[0] == 1 and image_tensor.shape[1] == 1:
                logger.info(f"Skipping image with shape {image_tensor.shape} - id: {id}")
                continue

        except Exception as e:
            logger.info(f"Image URL: {image_url}")
            logger.error(f"Image error {e} - id: {id}")
            continue

        image_path = os.path.join(image_dir, directory_name, f"{id}.jpg")
        raw_image.save(image_path)
        logger.info(f"Saved image to {image_path}")

    return

def filter_urls(base_dir, directory_name):
    directory_path = os.path.join(base_dir, directory_name)
    csv_file = "datawithtopics_merged.csv"
    df = pd.read_csv(os.path.join(directory_path, csv_file))

    # get 'id' and 'image_url' columns and drop rows with missing image_url
    df_image = df[['id', 'image_url']].dropna(subset=['image_url']) 

    df_image = df_image[0:10]

    id_list = df_image['id'].tolist()
    image_url_list = df_image['image_url'].tolist()

    return df, id_list, image_url_list


if __name__ == "__main__":

    parser = argparse.ArgumentParser(description="Download images from URLs.")
    # for base_dir too
    parser.add_argument("base_dir", type=str, help="The base directory to save images.")
    parser.add_argument("directory_name", type=str, help="The directory name to save images.")
    base_dir = "/home/vsl333/datasets/news-bert-data/bertopic/allcsvtopics"
    args = parser.parse_args()

    data_df, id_list, image_url_list = filter_urls(args.directory_name)
    download_images(args.directory_name, id_list, image_url_list)

In [14]:
# give list of director names in the base_dir, not path, just names
dir_list = []
for directory_name in os.listdir(base_dir):
    if os.path.isdir(os.path.join(base_dir, directory_name)):
        dir_list.append(directory_name)

print(dir_list)


['2023-11-01_2023-11-30', '2023-06-01_2023-06-30', '2024-01-01_2024-01-31', '2024-03-01_2024-03-31', '2023-08-01_2023-08-31', '2024-02-01_2024-02-29', '2024-04-01_2024-04-30', '2023-05-01_2023-05-31', '2023-10-01_2023-10-31', '2023-12-01_2023-12-31', '2023-07-01_2023-07-31', '2023-09-01_2023-09-30']
