Skip to content

Commit

Permalink
Use weighted sampling
Browse files Browse the repository at this point in the history
  • Loading branch information
victorlin committed Apr 17, 2024
1 parent 25db9e4 commit 52501d6
Show file tree
Hide file tree
Showing 31 changed files with 227 additions and 254 deletions.
137 changes: 40 additions & 97 deletions generate-subsampling-config.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

class Sample:
group_by: Optional[List[str]]
size: Optional[int]
weight: Optional[int]
min_date: Optional[str]
max_date: Optional[str]
excludes: Optional[List[str]]
Expand All @@ -38,8 +38,8 @@ def to_dict(self):
if self.group_by:
options['group_by'] = self.group_by

if self.size:
options['max_sequences'] = self.size
if self.weight:
options['weight'] = self.weight

if self.min_date:
options['min_date'] = self.min_date
Expand All @@ -57,9 +57,15 @@ def to_dict(self):


class Config:
samples: List[Sample]
def __init__(self):
self.samples = []
size: int
samples: Optional[List[Sample]]

def __init__(self, size, samples=None):
if samples == None:
samples = []

self.size = size
self.samples = samples

def add(self, new_sample: Sample):
if any(new_sample.name == sample.name for sample in self.samples):
Expand All @@ -69,13 +75,14 @@ def add(self, new_sample: Sample):

def to_dict(self):
return {
'size': self.size,
'samples': {
sample.name: sample.to_dict() for sample in self.samples
}
}

def to_file(self, path):
print(f'Writing {path}. n={sum(sample.size for sample in self.samples)}')
print(f'Writing {path}.')
with open(path, 'w') as f:
yaml.dump(self.to_dict(), f, sort_keys=False)

Expand Down Expand Up @@ -124,11 +131,9 @@ def write_region_time_builds():
build_name = f"{region.lower().replace(' ', '-')}_{time.lower()}"
filename = Path(SUBSAMPLING_CONFIG_DIR, f"{build_name}.yaml")

config = Config()

# Global gets special treatment because it is not a region.
if region == 'Global':
target_size = 5150
config = Config(size=5150)

locations = [
'Africa',
Expand Down Expand Up @@ -177,10 +182,7 @@ def write_region_time_builds():
'year',
'month',
],
size=int(
target_size
* weights[location] / sum_location_weights
),
weight=weights[location],
excludes=excludes[location],
))
else:
Expand All @@ -193,11 +195,7 @@ def write_region_time_builds():
'year',
'month',
],
size=int(
target_size
* WEIGHT_EARLY / (WEIGHT_EARLY + WEIGHT_RECENT)
* weights[location] / sum_location_weights
),
weight=(WEIGHT_EARLY * weights[location]),
max_date=time,
excludes=excludes[location],
))
Expand All @@ -210,18 +208,14 @@ def write_region_time_builds():
GROUP_BY_GEOGRAPHICAL_RESOLUTION[location],
GROUP_BY_RECENT_TEMPORAL_RESOLUTION[time],
],
size=int(
target_size
* WEIGHT_RECENT / (WEIGHT_EARLY + WEIGHT_RECENT)
* weights[location] / sum_location_weights
),
weight=(WEIGHT_RECENT * weights[location]),
min_date=time,
excludes=excludes[location],
))

# Asia gets special treatment because two countries must be weighted differently.
elif region == 'Asia':
target_size = 4375
config = Config(size=4375)

locations = [
'Asia',
Expand Down Expand Up @@ -255,11 +249,7 @@ def write_region_time_builds():
'year',
'month',
],
size=int(
target_size
* WEIGHT_FOCAL / (WEIGHT_FOCAL + WEIGHT_CONTEXTUAL)
* weights[location] / sum_location_weights
),
weight=(WEIGHT_FOCAL * weights[location]),
excludes=excludes[location],
))

Expand All @@ -271,10 +261,7 @@ def write_region_time_builds():
'year',
'month',
],
size=int(
target_size
* WEIGHT_CONTEXTUAL / (WEIGHT_FOCAL + WEIGHT_CONTEXTUAL)
),
weight=WEIGHT_CONTEXTUAL * sum_location_weights,
excludes=['region=Asia'],
))

Expand All @@ -289,12 +276,7 @@ def write_region_time_builds():
'year',
'month',
],
size=int(
target_size
* WEIGHT_EARLY / (WEIGHT_EARLY + WEIGHT_RECENT)
* WEIGHT_FOCAL / (WEIGHT_FOCAL + WEIGHT_CONTEXTUAL)
* weights[location] / sum_location_weights
),
weight=(WEIGHT_EARLY * WEIGHT_FOCAL * weights[location]),
max_date=time,
excludes=excludes[location],
))
Expand All @@ -307,11 +289,7 @@ def write_region_time_builds():
'year',
'month',
],
size=int(
target_size
* WEIGHT_EARLY / (WEIGHT_EARLY + WEIGHT_RECENT)
* WEIGHT_CONTEXTUAL / (WEIGHT_FOCAL + WEIGHT_CONTEXTUAL)
),
weight=(WEIGHT_EARLY * WEIGHT_CONTEXTUAL * sum_location_weights),
max_date=time,
excludes=['region=Asia'],
))
Expand All @@ -327,12 +305,7 @@ def write_region_time_builds():
'year',
'month',
],
size=int(
target_size
* WEIGHT_RECENT / (WEIGHT_EARLY + WEIGHT_RECENT)
* WEIGHT_FOCAL / (WEIGHT_FOCAL + WEIGHT_CONTEXTUAL)
* weights[location] / sum_location_weights
),
weight=(WEIGHT_RECENT * WEIGHT_FOCAL * weights[location]),
min_date=time,
excludes=excludes[location],
))
Expand All @@ -346,18 +319,14 @@ def write_region_time_builds():
'year',
'month',
],
size=int(
target_size
* WEIGHT_RECENT / (WEIGHT_EARLY + WEIGHT_RECENT)
* WEIGHT_CONTEXTUAL / (WEIGHT_FOCAL + WEIGHT_CONTEXTUAL)
),
weight=(WEIGHT_RECENT * WEIGHT_CONTEXTUAL * sum_location_weights),
min_date=time,
excludes=['region=Asia'],
))

# Everything else is a "standard" region with dynamic geographical/temporal grouping.
else:
target_size = 4000
config = Config(size=4000)

if time == 'all-time':
# Focal sequences for region
Expand All @@ -368,11 +337,8 @@ def write_region_time_builds():
'year',
'month',
],
size=int(
target_size
* WEIGHT_FOCAL / (WEIGHT_FOCAL + WEIGHT_CONTEXTUAL)
),
excludes=[f'region!={region}']
weight=WEIGHT_FOCAL,
excludes=[f'region!={region}'],
))

# Contextual sequences from the rest of the world
Expand All @@ -383,11 +349,8 @@ def write_region_time_builds():
'year',
'month',
],
size=int(
target_size
* WEIGHT_CONTEXTUAL / (WEIGHT_FOCAL + WEIGHT_CONTEXTUAL)
),
excludes=[f'region={region}']
weight=WEIGHT_CONTEXTUAL,
excludes=[f'region={region}'],
))
else:
# Early focal sequences for region
Expand All @@ -398,11 +361,7 @@ def write_region_time_builds():
'year',
'month',
],
size=int(
target_size
* WEIGHT_EARLY / (WEIGHT_EARLY + WEIGHT_RECENT)
* WEIGHT_FOCAL / (WEIGHT_FOCAL + WEIGHT_CONTEXTUAL)
),
weight=(WEIGHT_EARLY * WEIGHT_FOCAL),
max_date=time,
excludes=[f'region!={region}'],
))
Expand All @@ -415,11 +374,7 @@ def write_region_time_builds():
'year',
'month',
],
size=int(
target_size
* WEIGHT_EARLY / (WEIGHT_EARLY + WEIGHT_RECENT)
* WEIGHT_CONTEXTUAL / (WEIGHT_FOCAL + WEIGHT_CONTEXTUAL)
),
weight=(WEIGHT_EARLY * WEIGHT_CONTEXTUAL),
max_date=time,
excludes=[f'region={region}'],
))
Expand All @@ -431,11 +386,7 @@ def write_region_time_builds():
GROUP_BY_GEOGRAPHICAL_RESOLUTION[region],
GROUP_BY_RECENT_TEMPORAL_RESOLUTION[time],
],
size=int(
target_size
* WEIGHT_RECENT / (WEIGHT_EARLY + WEIGHT_RECENT)
* WEIGHT_FOCAL / (WEIGHT_FOCAL + WEIGHT_CONTEXTUAL)
),
weight=(WEIGHT_RECENT * WEIGHT_FOCAL),
min_date=time,
excludes=[f'region!={region}'],
))
Expand All @@ -447,53 +398,45 @@ def write_region_time_builds():
'country',
GROUP_BY_RECENT_TEMPORAL_RESOLUTION[time],
],
size=int(
target_size
* WEIGHT_RECENT / (WEIGHT_EARLY + WEIGHT_RECENT)
* WEIGHT_CONTEXTUAL / (WEIGHT_FOCAL + WEIGHT_CONTEXTUAL)
),
weight=(WEIGHT_RECENT * WEIGHT_CONTEXTUAL),
min_date=time,
excludes=[f'region={region}'],
))

# Double check the total sample size.
total_size = sum(sample.size for sample in config.samples)
assert target_size == total_size

config.to_file(filename)


def write_reference_build():
config = Config()
config = Config(size=300)
config.add(Sample(
name='clades',
group_by=['Nextstrain_clade'],
size=300,
weight=1,
))
filename = Path(SUBSAMPLING_CONFIG_DIR, f"reference.yaml")
config.to_file(filename)


def write_ci_build():
config = Config()
config = Config(size=30)
config.add(Sample(
name='region',
group_by=[
'division',
'year',
'month',
],
size=20,
weight=2,
disable_probabilistic_sampling=True,
excludes=['region!=Europe']
excludes=['region!=Europe'],
))
config.add(Sample(
name='global',
group_by=[
'year',
'month',
],
size=10,
weight=1,
disable_probabilistic_sampling=True,
excludes=['region=Europe'],
# TODO: add Priority(type=proximity, focus=region)
Expand Down
9 changes: 5 additions & 4 deletions subsampling/africa_1m.yaml
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
size: 4000
samples:
focal_early:
group_by:
- country
- year
- month
max_sequences: 640
weight: 4
max_date: 1M
exclude:
- region!=Africa
Expand All @@ -13,23 +14,23 @@ samples:
- country
- year
- month
max_sequences: 160
weight: 1
max_date: 1M
exclude:
- region=Africa
focal_recent:
group_by:
- country
- week
max_sequences: 2560
weight: 16
min_date: 1M
exclude:
- region!=Africa
context_recent:
group_by:
- country
- week
max_sequences: 640
weight: 4
min_date: 1M
exclude:
- region=Africa

0 comments on commit 52501d6

Please sign in to comment.