Skip to content

Commit

Permalink
Merge pull request #71 from craigthomas/version-1-test-coverage
Browse files Browse the repository at this point in the history
Add unit tests for Version 1.0.0 release
  • Loading branch information
craigthomas committed Sep 3, 2022
2 parents d614b9d + 02031af commit 861e818
Show file tree
Hide file tree
Showing 26 changed files with 1,959 additions and 839 deletions.
66 changes: 49 additions & 17 deletions assembler.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,13 @@
# I M P O R T S ###############################################################

import argparse
import sys

from cocoasm.exceptions import TranslationError, ParseError
from cocoasm.program import Program
from cocoasm.virtualfiles.virtualfile import CoCoFile
from cocoasm.virtualfiles.binary import BinaryFile
from cocoasm.virtualfiles.cassette import CassetteFile
from cocoasm.virtualfiles.virtual_file import VirtualFileType, VirtualFile
from cocoasm.virtualfiles.source_file import SourceFile, SourceFileType
from cocoasm.virtualfiles.coco_file import CoCoFile

# F U N C T I O N S ###########################################################

Expand Down Expand Up @@ -56,14 +58,34 @@ def parse_arguments():
return parser.parse_args()


def throw_error(error):
"""
Prints out an error message.
:param error: the error message to throw
"""
print(error.value)
print("{}".format(str(error.statement)))
sys.exit(1)


def main(args):
"""
Runs the assembler with the specified arguments.
:param args: the command-line arguments
"""
program = Program(width=args.width)
program.process(args.filename)
source_file = SourceFile(args.filename)
source_file.read_file()
program = Program()

try:
program.process(source_file.get_buffer())
except TranslationError as error:
throw_error(error)
except ParseError as error:
throw_error(error)

coco_file = CoCoFile(
name=program.name or args.name,
load_addr=program.origin,
Expand All @@ -72,18 +94,25 @@ def main(args):
)

if args.symbols:
program.print_symbol_table()
print("-- Symbol Table --")
for symbol in program.get_symbol_table():
print(symbol)

if args.print:
program.print_statements()
print("-- Assembled Statements --")
for statement in program.get_statements():
print(statement)

if args.bin_file:
try:
binary_file = BinaryFile()
binary_file.open_host_file_for_write(args.bin_file, append=args.append)
binary_file.save_to_host_file(coco_file)
binary_file.close_host_file()
except ValueError as error:
virtual_file = VirtualFile(
SourceFile(args.bin_file, file_type=SourceFileType.BINARY),
VirtualFileType.BINARY
)
virtual_file.open_virtual_file()
virtual_file.add_coco_file(coco_file)
virtual_file.save_virtual_file(append_mode=args.append)
except Exception as error:
print("Unable to save binary file:")
print(error)

Expand All @@ -92,11 +121,14 @@ def main(args):
print("No name for the program specified, not creating cassette file")
return
try:
cas_file = CassetteFile()
cas_file.open_host_file_for_write(args.cas_file, append=args.append)
cas_file.save_to_host_file(coco_file)
cas_file.close_host_file()
except ValueError as error:
virtual_file = VirtualFile(
SourceFile(args.cas_file, file_type=SourceFileType.BINARY),
VirtualFileType.CASSETTE
)
virtual_file.open_virtual_file()
virtual_file.add_coco_file(coco_file)
virtual_file.save_virtual_file(append_mode=args.append)
except Exception as error:
print("Unable to save cassette file:")
print(error)

Expand Down
24 changes: 11 additions & 13 deletions cocoasm/operands.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,8 +265,7 @@ def translate(self):
if registers[1] not in REGISTERS:
raise OperandTypeError("[{}] unknown register".format(registers[1]))

post_byte |= 0x00 if registers[0] == "D" else 0x00
post_byte |= 0x00 if registers[1] == "D" else 0x00
# Implicit is that a value of 0x00 means register D

post_byte |= 0x10 if registers[0] == "X" else 0x00
post_byte |= 0x01 if registers[1] == "X" else 0x00
Expand Down Expand Up @@ -471,7 +470,7 @@ def translate(self):
)
size = self.instruction.mode.ind_sz

if not type(self.value) == str and self.value.is_address():
if type(self.value) != str and self.value.is_address():
size += 2
return CodePackage(
op_code=NumericValue(self.instruction.mode.ind),
Expand All @@ -481,7 +480,7 @@ def translate(self):
max_size=size,
)

if not type(self.value) == str and self.value.is_numeric():
if type(self.value) != str and self.value.is_numeric():
size += 2
return CodePackage(
op_code=NumericValue(self.instruction.mode.ind),
Expand All @@ -507,7 +506,7 @@ def translate(self):
if "S" in self.right:
raw_post_byte |= 0x60

if self.left == "" or (not type(self.left) == str and self.left.is_numeric() and self.left.int == 0):
if self.left == "" or (type(self.left) != str and self.left.is_numeric() and self.left.int == 0):
if "-" in self.right or "+" in self.right:
if self.right == "X+" or self.right == "Y+" or self.right == "U+" or self.right == "S+":
raise OperandTypeError("[{}] not allowed as an extended indirect value".format(self.right))
Expand Down Expand Up @@ -603,13 +602,12 @@ def __init__(self, operand_string, instruction):
self.right = self.value.right

def resolve_symbols(self, symbol_table):
if self.left != "":
if self.left not in ["A", "B", "D"]:
self.left = Value.create_from_str(self.left, self.instruction, default_mode_extended=False)
if self.left.is_symbol():
self.left = self.left.resolve(symbol_table)
if self.left.is_address_expression() or self.left.is_expression():
self.left = self.left.resolve(symbol_table)
if self.left != "" and self.left not in ["A", "B", "D"]:
self.left = Value.create_from_str(self.left, self.instruction, default_mode_extended=False)
if self.left.is_symbol():
self.left = self.left.resolve(symbol_table)
if self.left.is_address_expression() or self.left.is_expression():
self.left = self.left.resolve(symbol_table)
return self

def translate(self):
Expand All @@ -634,7 +632,7 @@ def translate(self):
if "S" in self.right:
raw_post_byte |= 0x60

if self.left == "" or (not type(self.left) == str and self.left.is_numeric() and self.left.int == 0):
if self.left == "" or (type(self.left) != str and self.left.is_numeric() and self.left.int == 0):
raw_post_byte |= 0x80
if "-" in self.right or "+" in self.right:
if "+" in self.right:
Expand Down
154 changes: 62 additions & 92 deletions cocoasm/program.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,10 @@
"""
# I M P O R T S ###############################################################

import sys

from cocoasm.exceptions import TranslationError, ParseError, ValueTypeError
from cocoasm.exceptions import TranslationError, ValueTypeError
from cocoasm.statement import Statement
from cocoasm.values import AddressValue, NoneValue
from cocoasm.virtualfiles.source_file import SourceFile

# C L A S S E S ###############################################################

Expand All @@ -21,59 +20,38 @@ class Program(object):
contains a list of statements. Additionally, a Program keeps track of all
the user-defined symbols in the program.
"""
def __init__(self, width=100):
def __init__(self):
self.symbol_table = dict()
self.statements = []
self.address = 0x0
self.origin = NoneValue()
self.name = None
self.width = width

def process(self, filename):
def process(self, source_file):
"""
Processes a filename for assembly.
:param filename: the name of the file to process
:param source_file: the source file to process
"""
try:
self.parse(filename)
self.translate_statements()
except TranslationError as error:
self.throw_error(error)
except ParseError as error:
self.throw_error(error)
self.statements = self.parse(source_file)
self.translate_statements()

def parse(self, filename):
@classmethod
def parse(cls, contents):
"""
Parses a single file and saves the set of statements.
:param filename: the name of the file to process
"""
self.statements = self.parse_file(filename)

def parse_file(self, filename):
"""
Parses all of the lines in a file, and transforms each line into
a Statement. Returns a list of all the statements in the file.
:param filename: the name of the file to parse
:param contents: a list of strings, each string represents one line of assembly
"""
statements = []
if not filename:
return statements

try:
with open(filename) as infile:
for line in infile:
statement = Statement(line)
if not statement.is_empty and not statement.is_comment_only:
statements.append(statement)
except ParseError as error:
self.throw_error(error)

for line in contents:
statement = Statement(line)
if not statement.is_empty and not statement.is_comment_only:
statements.append(statement)
return statements

def process_mnemonics(self, statements):
@classmethod
def process_mnemonics(cls, statements):
"""
Given a list of statements, processes the mnemonics on each statement, and
assigns each statement an Instruction object. If the statement is the
Expand All @@ -85,8 +63,14 @@ def process_mnemonics(self, statements):
"""
processed_statements = []
for statement in statements:
include = self.process_mnemonics(self.parse_file(statement.get_include_filename()))
processed_statements.extend(include if include else [statement])
include_filename = statement.get_include_filename()
if include_filename:
include_source = SourceFile(include_filename)
include_source.read_file()
include = cls.process_mnemonics(cls.parse(include_source.get_buffer()))
processed_statements.extend(include)
else:
processed_statements.extend([statement])
return processed_statements

def save_symbol(self, index, statement):
Expand All @@ -112,46 +96,40 @@ def translate_statements(self):
Translates all the parsed statements into their respective
opcodes.
"""
try:
self.statements = self.process_mnemonics(self.statements)
for index, statement in enumerate(self.statements):
self.save_symbol(index, statement)
self.statements = self.process_mnemonics(self.statements)
for index, statement in enumerate(self.statements):
self.save_symbol(index, statement)

for index, statement in enumerate(self.statements):
try:
statement.resolve_symbols(self.symbol_table)
except ValueTypeError as error:
raise TranslationError(str(error), statement)
for index, statement in enumerate(self.statements):
statement.resolve_symbols(self.symbol_table)

for index, statement in enumerate(self.statements):
statement.translate()

while not self.all_sizes_fixed():
for index, statement in enumerate(self.statements):
statement.translate()
if not statement.fixed_size:
statement.determine_pcr_relative_sizes(self.statements, index)

while not self.all_sizes_fixed():
for index, statement in enumerate(self.statements):
if not statement.fixed_size:
statement.determine_pcr_relative_sizes(self.statements, index)
address = 0
for index, statement in enumerate(self.statements):
address = statement.set_address(address)
address += statement.code_pkg.size

address = 0
for index, statement in enumerate(self.statements):
address = statement.set_address(address)
address += statement.code_pkg.size
for index, statement in enumerate(self.statements):
statement.fix_addresses(self.statements, index)

for index, statement in enumerate(self.statements):
statement.fix_addresses(self.statements, index)

# Update the symbol table with the proper addresses
for symbol, value in self.symbol_table.items():
if value.is_address():
self.symbol_table[symbol] = self.statements[value.int].code_pkg.address

# Find the origin and name of the project
for statement in self.statements:
if statement.instruction.is_origin:
self.origin = statement.code_pkg.address
if statement.instruction.is_name:
self.name = statement.operand.operand_string
except TranslationError as error:
self.throw_error(error)
# Update the symbol table with the proper addresses
for symbol, value in self.symbol_table.items():
if value.is_address():
self.symbol_table[symbol] = self.statements[value.int].code_pkg.address

# Find the origin and name of the project
for statement in self.statements:
if statement.instruction.is_origin:
self.origin = statement.code_pkg.address
if statement.instruction.is_name:
self.name = statement.operand.operand_string

def get_binary_array(self):
"""
Expand Down Expand Up @@ -187,30 +165,22 @@ def all_sizes_fixed(self):
return False
return True

def print_symbol_table(self):
def get_symbol_table(self):
"""
Prints out the symbol table and any values contained within it.
Returns a list of strings. Each string contains one entry from the symbol table.
"""
print("-- Symbol Table --")
lines = []
for symbol, value in self.symbol_table.items():
print("${} {}".format(value.hex().ljust(4, ' '), symbol))
lines.append("${} {}".format(value.hex().ljust(4, ' '), symbol))
return lines

def print_statements(self):
def get_statements(self):
"""
Prints out the assembled statements.
Returns a list of strings. Each string represents one assembled statement
"""
print("-- Assembled Statements --")
lines = []
for index, statement in enumerate(self.statements):
print("{}".format(str(statement).ljust(self.width)))

def throw_error(self, error):
"""
Prints out an error message.
:param error: the error message to throw
"""
print(error.value)
print("{}".format(str(error.statement).ljust(self.width)))
sys.exit(1)
lines.append("{}".format(str(statement)))
return lines

# E N D O F F I L E #######################################################

0 comments on commit 861e818

Please sign in to comment.