# Archive

> Download files from a remote archive and optionally extract them.

In [None]:
#| default_exp utils.archive

In [None]:
#| hide
from nbdev.showdoc import *

In [113]:
#| export
import os, pathlib, itertools
from pathlib import Path
from dataclasses import dataclass, field, KW_ONLY
from typing import Optional, List, ClassVar, Any, TypeAlias

In [88]:
#| export
from iza.types import PathLike, PathType
from iza.static import EXT_PY

### Directory Viewer

#### Archive Downloader

Where the names used by the code below come from:

- `Directory` defined in `_02_utils/_03_directory.ipynb`
- `ConsoleType` defined in `_02_utils/_03_directory.ipynb`
- `get_console` imported in `_02_utils/_03_directory.ipynb`
- `is_rich_available` defined in `_02_utils/_08_archive.ipynb`
- `urljoin` defined in `_02_utils/_01_files.ipynb`
- `parse_url` imported in `_02_utils/_01_files.ipynb`

In [None]:
#| export
@dataclass
class ArchiveDownloader:
    """Download a set of entries from a remote archive and optionally extract them.

    All fields are keyword-only.

    Fields:
        rootdir: Base location; joined with `archive` (or with each item of
            `archives`) via `urljoin`.
        archive: Archive name appended to `rootdir` to form `path`.
        entries: One entry name or a list of them; a single value is wrapped
            in a list during `__post_init__`.
        savedir: Local directory for downloads. `~` is expanded, the value is
            replaced by a `Path`, and `make_missing_dirs` is called on it.
        extract: When true, `execute` calls `extract_files` after downloading.
        cleanup: Forwarded as `remove=` to `decompress_gunzip`.
        compound_archive: When true and `archives` is given, URLs are built
            for every (archive, entry) pair instead of from `archive` alone.
        archives: Archive names used when `compound_archive` is true.
        console: Console used for rich output; filled from `get_console()`
            when rich is available and none was supplied.
        progress: Progress display used for rich output; created on demand by
            `get_progress` when none was supplied.

    NOTE(review): `ConsoleType`, `ProgressType`, `Progress`, `tqdm`,
    `get_console`, `is_rich_available`, `urljoin`, `parse_url`, `stream_file`,
    `make_missing_dirs`, `Directory`, `is_tarball`, `is_gz`,
    `decompress_tarball` and `decompress_gunzip` are not defined in this cell
    and must already be in scope.
    """
    _: KW_ONLY
    rootdir: str
    archive: str
    # PEP 604 union: `Union` is not imported in this module.
    entries: str | list[str]
    savedir: str
    extract: bool = False
    cleanup: bool = False
    compound_archive: bool = False
    archives: Optional[list[str]] = None
    console: Optional[ConsoleType] = None
    progress: Optional[ProgressType] = None

    def __post_init__(self):
        """Normalise `entries`, set up rich helpers, and prepare `savedir`."""
        if not isinstance(self.entries, list):
            self.entries = [self.entries]

        if is_rich_available():
            # Respect a console / progress object supplied by the caller.
            if self.console is None:
                self.console = get_console()
            self.progress = self.get_progress()

        self.savedir = Path(self.savedir).expanduser()
        make_missing_dirs(self.savedir)

    def get_progress(self):
        """Return the progress display, creating one if needed.

        Returns the existing `self.progress` when it is set. Otherwise a new
        `Progress` bound to `self.console` is created, stored on the instance
        and returned. Returns `None` when `Progress` itself is `None`.
        """
        existing = getattr(self, 'progress', None)
        if existing is not None:
            return existing
        if Progress is None:
            return None
        self.progress = Progress(console=self.console)
        return self.progress

    @property
    def path(self) -> str:
        """`rootdir` joined with `archive`."""
        return urljoin(self.rootdir, self.archive)

    @property
    def urls(self) -> list[str]:
        """URLs of every entry to download.

        For a compound archive with `archives` set, one URL per
        (archive, entry) pair; otherwise one URL per entry under `path`.
        """
        if self.compound_archive and self.archives is not None:
            return [
                urljoin(self.rootdir, archive, entry)
                for archive, entry in itertools.product(self.archives, self.entries)
            ]
        return [urljoin(self.path, entry) for entry in self.entries]

    def _local_path(self, url: str) -> Path:
        """Return the file under `savedir` that corresponds to `url`."""
        return self.savedir / Path(parse_url(url).path).name

    def _download_if_missing(self, url: str) -> None:
        """Stream `url` into `savedir` unless the target file already exists."""
        target = self._local_path(url)
        if not target.exists():
            stream_file(url, str(target))

    def _extract_one(self, file: Path) -> None:
        """Decompress `file` if it is a tarball or a gzip file; otherwise do nothing."""
        if is_tarball(file):
            decompress_tarball(file)
        elif is_gz(file):
            decompress_gunzip(file, remove=self.cleanup)

    def download_missing_files(self) -> None:
        """Download every URL whose local file does not exist yet.

        Progress is shown with the rich progress display when available and
        with `tqdm` otherwise.
        """
        urls = self.urls  # the property rebuilds the list on every access
        if is_rich_available() and self.progress is not None:
            with self.progress:
                task = self.progress.add_task("[cyan]Downloading...", total=len(urls))
                for url in urls:
                    self._download_if_missing(url)
                    # Advance for skipped files too, so the bar reaches its total.
                    self.progress.advance(task)
        else:
            for url in tqdm(urls, desc='Downloading'):
                self._download_if_missing(url)

    def extract_files(self) -> None:
        """Decompress the downloaded files that are tarballs or gzip files."""
        # Use the same local names the download step wrote; dict.fromkeys
        # drops duplicates (compound archives map to one file) in order.
        files = list(dict.fromkeys(self._local_path(url) for url in self.urls))
        if is_rich_available() and self.progress is not None:
            with self.progress:
                task = self.progress.add_task("[cyan]Extracting...", total=len(files))
                for file in files:
                    self._extract_one(file)
                    self.progress.advance(task)
        else:
            for file in tqdm(files, desc='Extracting'):
                self._extract_one(file)

    def execute(self) -> None:
        """Download missing files, optionally extract them, then print `savedir`."""
        use_rich = is_rich_available() and self.console is not None
        if use_rich:
            self.console.print(f"Processing archive: [bold cyan]{self.archive}[/bold cyan]")
        else:
            print(f"Processing archive: {self.archive}")

        self.download_missing_files()
        if self.extract:
            self.extract_files()

        directory = Directory(self.savedir)
        if use_rich:
            directory.print_rich(self.console)
        else:
            directory.print()

@dataclass
class AmazonArchiveDownloader(ArchiveDownloader):
    """`ArchiveDownloader` whose root URL is derived from an S3 bucket and region.

    Fields:
        bucket: Bucket name used as the host prefix of the root URL.
        region: Region placed in the host name; defaults to ``'us-east-2'``.
        rootdir: Always overwritten in `__post_init__` with
            ``https://{bucket}.s3.{region}.amazonaws.com``; any value passed
            by the caller is ignored.
    """
    bucket: str
    region: str = 'us-east-2'
    # The base class declares `rootdir` as required, yet it is always replaced
    # below. Re-declare it with a default so callers need not pass a dummy
    # value. kw_only=True keeps it keyword-only as in the base class, so it
    # does not end up ahead of the positional `bucket` field.
    rootdir: str = field(default='', kw_only=True)

    def __post_init__(self):
        """Run the base initialisation, then set `rootdir` from bucket and region."""
        super().__post_init__()
        self.rootdir = f"https://{self.bucket}.s3.{self.region}.amazonaws.com"

In [None]:
#| hide
# Export the `#| export` cells of this notebook to the library module.
import nbdev
nbdev.nbdev_export()